Skip to content

Commit

Permalink
Update DocIterator (#2060)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Update DocIterator
Sort children of AndIterator by EstimateIterateCost instead of DocFreq

Issue link:#1862

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
  • Loading branch information
yangzq50 authored Oct 17, 2024
1 parent c452460 commit ae34fad
Show file tree
Hide file tree
Showing 16 changed files with 115 additions and 259 deletions.
11 changes: 6 additions & 5 deletions src/storage/invertedindex/search/and_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ import infinity_exception;
namespace infinity {

AndIterator::AndIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIterator(std::move(iterators)) {
std::sort(children_.begin(), children_.end(), [](const auto &lhs, const auto &rhs) { return lhs->GetDF() < rhs->GetDF(); });
// init df
doc_freq_ = std::numeric_limits<u32>::max();
std::sort(children_.begin(), children_.end(), [](const auto &lhs, const auto &rhs) {
return lhs->GetEstimateIterateCost() < rhs->GetEstimateIterateCost();
});
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
for (SizeT i = 0; i < children_.size(); i++) {
const auto &it = children_[i];
doc_freq_ = std::min(doc_freq_, it->GetDF());
bm25_score_upper_bound_ += children_[i]->BM25ScoreUpperBound();
estimate_iterate_cost_ = std::min(estimate_iterate_cost_, it->GetEstimateIterateCost());
// for minimum_should_match parameter
switch (it->GetType()) {
case DocIteratorType::kTermDocIterator:
Expand All @@ -50,7 +51,7 @@ AndIterator::AndIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIte
}
}

bool AndIterator::Next(RowID doc_id) {
bool AndIterator::Next(const RowID doc_id) {
assert(doc_id != INVALID_ROWID);
if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id)
return true;
Expand Down
9 changes: 4 additions & 5 deletions src/storage/invertedindex/search/and_not_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ import internal_types;
namespace infinity {

// Builds an AND-NOT iterator over `iterators`, which are handed to the MultiDocIterator base.
// children_[0] is treated specially: it is excluded from sorting and is the sole source of this
// iterator's BM25 score upper bound and estimated iterate cost. The remaining children
// (children_[1..]) are sorted in ascending order.
// NOTE(review): this span is a rendered diff, so it appears to interleave removed and added lines.
// The GetDF-based sort, Next(0) and the doc_freq_ assignment look like the pre-change side
// (GetDF/doc_freq_ are dropped from DocIterator elsewhere in this commit) — confirm against the
// committed file before relying on them.
AndNotIterator::AndNotIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIterator(std::move(iterators)) {
// Sort every child except the first by ascending document frequency.
std::sort(children_.begin() + 1, children_.end(), [](const auto &lhs, const auto &rhs) { return lhs->GetDF() < rhs->GetDF(); });
// initialize doc_id_ to first doc
Next(0);
// init df
doc_freq_ = children_[0]->GetDF();
// Sort every child except the first by ascending estimated iterate cost.
std::sort(children_.begin() + 1, children_.end(), [](const auto &lhs, const auto &rhs) {
return lhs->GetEstimateIterateCost() < rhs->GetEstimateIterateCost();
});
// Score bound and cost are copied from the first child only; the other children do not contribute.
bm25_score_upper_bound_ = children_[0]->BM25ScoreUpperBound();
estimate_iterate_cost_ = children_[0]->GetEstimateIterateCost();
}

bool AndNotIterator::Next(RowID doc_id) {
Expand Down
113 changes: 0 additions & 113 deletions src/storage/invertedindex/search/blockmax_maxscore_iterator.cpp

This file was deleted.

84 changes: 0 additions & 84 deletions src/storage/invertedindex/search/blockmax_maxscore_iterator.cppm

This file was deleted.

7 changes: 5 additions & 2 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,15 @@ BlockMaxWandIterator::~BlockMaxWandIterator() {
BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iterators)
: MultiDocIterator(std::move(iterators)), pivot_(sorted_iterators_.size()) {
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
SizeT num_iterators = children_.size();
for (SizeT i = 0; i < num_iterators; i++){
TermDocIterator *tdi = dynamic_cast<TermDocIterator *>(children_[i].get());
if (tdi == nullptr)
continue;
if (tdi == nullptr) {
UnrecoverableError("BMW only supports TermDocIterator");
}
bm25_score_upper_bound_ += tdi->BM25ScoreUpperBound();
estimate_iterate_cost_ += tdi->GetEstimateIterateCost();
sorted_iterators_.push_back(tdi);
}
next_sum_score_bm_low_cnt_dist_.resize(100, 0);
Expand Down
17 changes: 15 additions & 2 deletions src/storage/invertedindex/search/doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ export enum class DocIteratorType : u8 {
kFilterIterator,
};

export struct DocIteratorEstimateIterateCost {
u32 priority_ = 0;
u32 estimate_cost_ = 0;
friend auto operator<=>(const DocIteratorEstimateIterateCost &, const DocIteratorEstimateIterateCost &) = default;
auto &operator+=(DocIteratorEstimateIterateCost rhs) {
if (priority_ < rhs.priority_) {
std::swap(*this, rhs);
}
estimate_cost_ += (rhs.estimate_cost_ >> 3 * (priority_ - rhs.priority_));
return *this;
}
};

export class DocIterator {
public:
DocIterator() {}
Expand All @@ -45,7 +58,7 @@ public:

RowID DocID() const { return doc_id_; }

inline u32 GetDF() const { return doc_freq_; }
inline DocIteratorEstimateIterateCost GetEstimateIterateCost() const { return estimate_iterate_cost_; }

// Update doc_id_ to one larger than previous one.
// If has_blockmax is true, it ensures its BM25 score be larger than current threshold.
Expand Down Expand Up @@ -81,8 +94,8 @@ public:

protected:
RowID doc_id_{INVALID_ROWID};
u32 doc_freq_ = 0;
float threshold_ = 0.0f;
float bm25_score_upper_bound_ = std::numeric_limits<float>::max();
DocIteratorEstimateIterateCost estimate_iterate_cost_ = {};
};
} // namespace infinity
24 changes: 20 additions & 4 deletions src/storage/invertedindex/search/minimum_should_match_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,27 @@ MinimumShouldMatchIterator::MinimumShouldMatchIterator(Vector<UniquePtr<DocItera
}
tail_heap_.resize(minimum_should_match_ - 1u);
bm25_score_cache_docid_ = INVALID_ROWID;
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
for (const auto &child : children_) {
bm25_score_upper_bound_ += child->BM25ScoreUpperBound();
estimate_iterate_cost_ += child->GetEstimateIterateCost();
}
}

// Nothing to release by hand here; members and the base class clean up after themselves.
MinimumShouldMatchIterator::~MinimumShouldMatchIterator() = default;

// Threshold updates are not implemented in this variant: any call goes straight to
// UnrecoverableError("Unreachable code").
// NOTE(review): in this rendered diff this line appears to be the removed (pre-change) definition;
// the multi-line definition of the same function that follows looks like its replacement — confirm.
void MinimumShouldMatchIterator::UpdateScoreThreshold(float threshold) { UnrecoverableError("Unreachable code"); }
// Raises this iterator's score threshold and hands a derived threshold to each child.
// A call whose `threshold` is less than or equal to the stored threshold_ does nothing.
void MinimumShouldMatchIterator::UpdateScoreThreshold(const float threshold) {
    // Written as a negated `<=` on purpose so that the set of values that pass is exactly the
    // set that is not rejected by a `threshold <= threshold_` early return.
    if (!(threshold <= threshold_)) {
        threshold_ = threshold;
        // Requested threshold minus this iterator's own BM25 score upper bound.
        const float threshold_minus_own_bound = threshold - BM25ScoreUpperBound();
        for (const auto &child : children_) {
            // Each child gets the remainder plus its own upper bound, never less than zero.
            const float child_threshold = std::max(0.0f, threshold_minus_own_bound + child->BM25ScoreUpperBound());
            child->UpdateScoreThreshold(child_threshold);
        }
    }
}

bool MinimumShouldMatchIterator::Next(RowID doc_id) {
if (doc_id_ == INVALID_ROWID) {
Expand Down Expand Up @@ -160,13 +176,13 @@ u32 MinimumShouldMatchIterator::PopFromHeadHeap() {
}

Pair<bool, u32> MinimumShouldMatchIterator::PushToTailHeap(const u32 idx) {
auto comp = [&](const u32 lhs, const u32 rhs) { return children_[lhs]->GetDF() > children_[rhs]->GetDF(); };
auto comp = [&](const u32 lhs, const u32 rhs) { return children_[lhs]->GetEstimateIterateCost() > children_[rhs]->GetEstimateIterateCost(); };
if (tail_size_ < tail_heap_.size()) {
tail_heap_[tail_size_++] = idx;
std::push_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, comp);
return {false, std::numeric_limits<u32>::max()};
}
if (children_[idx]->GetDF() <= children_[tail_heap_.front()]->GetDF()) {
if (children_[idx]->GetEstimateIterateCost() <= children_[tail_heap_.front()]->GetEstimateIterateCost()) {
return {true, idx};
}
const auto result = tail_heap_.front();
Expand All @@ -179,7 +195,7 @@ Pair<bool, u32> MinimumShouldMatchIterator::PushToTailHeap(const u32 idx) {
// Removes the front entry of the tail heap and returns it. The heap occupies the first
// tail_size_ slots of tail_heap_, and its entries are indices into children_.
// Requires tail_size_ > 0 (asserted).
// NOTE(review): this span is a rendered diff; the lambda shows two comparison lines. The GetDF
// one appears to be the removed side and the GetEstimateIterateCost one the added side — confirm
// against the committed file.
u32 MinimumShouldMatchIterator::PopFromTailHeap() {
assert(tail_size_ > 0);
// The comparator is "greater-than", so the heap front is the entry with the smallest key;
// std::pop_heap moves that entry to the last slot of the heap range.
std::pop_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, [&](const u32 lhs, const u32 rhs) {
return children_[lhs]->GetDF() > children_[rhs]->GetDF();
return children_[lhs]->GetEstimateIterateCost() > children_[rhs]->GetEstimateIterateCost();
});
// Shrink the heap by one and return the entry that pop_heap just moved to the end.
return tail_heap_[--tail_size_];
}
Expand Down
5 changes: 4 additions & 1 deletion src/storage/invertedindex/search/multi_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ void MultiDocIterator::PrintTree(std::ostream &os, const String &prefix, bool is
os << prefix;
os << (is_final ? "└──" : "├──");
os << Name();
os << " (doc_freq: " << GetDF() << ")";
{
auto [level, cost] = GetEstimateIterateCost();
os << " (estimate_iterate_cost: " << level << ", " << cost << ")";
}
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << " (threshold: " << Threshold() << ")";
os << " (children count: " << children_.size() << ")";
Expand Down
Loading

0 comments on commit ae34fad

Please sign in to comment.