Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,10 @@ void QueryHelper::collect_range(const IndexQueryContextPtr& context,
}
}

bool QueryHelper::is_simple_phrase(const std::vector<TermInfo>& term_infos) {
return std::ranges::all_of(term_infos,
[](const auto& term_info) { return term_info.is_single_term(); });
}

#include "common/compile_check_end.h"
} // namespace doris::segment_v2
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class QueryHelper {
const DocRange& doc_range);
static void collect_range(const IndexQueryContextPtr& context, const SimilarityPtr& similarity,
const DocRange& doc_range);

static bool is_simple_phrase(const std::vector<TermInfo>& term_infos);
};

#include "common/compile_check_end.h"
Expand Down
106 changes: 84 additions & 22 deletions be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,36 @@ class DocSet {
virtual ~DocSet() = default;

virtual uint32_t advance() {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"advance() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"advance() method not implemented in base DocSet class");
}

virtual uint32_t seek(uint32_t target) {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"seek() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"seek() method not implemented in base DocSet class");
}

virtual uint32_t doc() const {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"doc() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"doc() method not implemented in base DocSet class");
}

virtual uint32_t size_hint() const {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"size_hint() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"size_hint() method not implemented in base DocSet class");
}

virtual uint32_t freq() const {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"freq() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"freq() method not implemented in base DocSet class");
}

virtual uint32_t norm() const {
throw doris::Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"norm() method not implemented in base DocSet class");
throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR,
"norm() method not implemented in base DocSet class");
}
};
using DocSetPtr = std::shared_ptr<DocSet>;

class MockDocSet : public DocSet {
public:
Expand All @@ -77,53 +78,114 @@ class MockDocSet : public DocSet {
}
}

uint32_t advance() override {
MockDocSet(std::vector<uint32_t> docs, std::map<uint32_t, std::vector<uint32_t>> doc_positions,
uint32_t size_hint_val = 0, uint32_t norm_val = 1)
: _docs(std::move(docs)),
_doc_positions(std::move(doc_positions)),
_size_hint_val(size_hint_val),
_norm_val(norm_val) {
if (_docs.empty()) {
_current_doc = TERMINATED;
} else {
std::ranges::sort(_docs.begin(), _docs.end());
_current_doc = _docs[0];
}
if (_size_hint_val == 0) {
_size_hint_val = static_cast<uint32_t>(_docs.size());
}
}

// Basic TermIterator-style interface (foundation methods)
bool next() {
if (_docs.empty() || _index >= _docs.size()) {
_current_doc = TERMINATED;
return TERMINATED;
return false;
}
++_index;
if (_index >= _docs.size()) {
_current_doc = TERMINATED;
return TERMINATED;
return false;
}
_current_doc = _docs[_index];
return _current_doc;
return true;
}

uint32_t seek(uint32_t target) override {
bool skipTo(uint32_t target) {
if (_docs.empty() || _index >= _docs.size()) {
_current_doc = TERMINATED;
return TERMINATED;
return false;
}
if (_current_doc >= target) {
return _current_doc;
return true;
}
auto it = std::lower_bound(_docs.begin() + _index, _docs.end(), target);
if (it == _docs.end()) {
_index = _docs.size();
_current_doc = TERMINATED;
return TERMINATED;
return false;
}
_index = static_cast<size_t>(it - _docs.begin());
_current_doc = *it;
return true;
}

uint32_t docFreq() const { return _size_hint_val; }

// DocSet virtual interface (built on top of basic methods)
uint32_t advance() override {
next();
return _current_doc;
}

uint32_t seek(uint32_t target) override {
skipTo(target);
return _current_doc;
}

uint32_t doc() const override { return _current_doc; }

uint32_t size_hint() const override { return _size_hint_val; }
uint32_t size_hint() const override { return docFreq(); }

uint32_t norm() const override { return _norm_val; }

uint32_t freq() const override {
if (_current_doc == TERMINATED) {
return 0;
}
auto it = _doc_positions.find(_current_doc);
if (it != _doc_positions.end()) {
return static_cast<uint32_t>(it->second.size());
}
return 1;
}

void append_positions_with_offset(uint32_t offset, std::vector<uint32_t>& output) {
if (_current_doc == TERMINATED) {
return;
}
auto it = _doc_positions.find(_current_doc);
if (it != _doc_positions.end()) {
size_t prev_size = output.size();
output.reserve(prev_size + it->second.size());
for (uint32_t pos : it->second) {
output.push_back(offset + pos);
}
}
}

void positions_with_offset(uint32_t offset, std::vector<uint32_t>& output) {
output.clear();
append_positions_with_offset(offset, output);
}

private:
std::vector<uint32_t> _docs;
std::map<uint32_t, std::vector<uint32_t>> _doc_positions;
size_t _index = 0;
uint32_t _current_doc = TERMINATED;
uint32_t _size_hint_val = 0;
uint32_t _norm_val = 1;
};

using MockDocSetPtr = std::shared_ptr<MockDocSet>;

} // namespace doris::segment_v2::inverted_index::query_v2
Original file line number Diff line number Diff line change
Expand Up @@ -147,29 +147,17 @@ Intersection<TDocSet, TOtherDocSet>::docset_mut_specialized(size_t ord) {
}
}

template class Intersection<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>;
template class Intersection<MockDocSetPtr, MockDocSetPtr>;

// create
template std::enable_if_t<
std::is_same_v<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>,
IntersectionPtr<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>>
Intersection<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>::create<
PositionPostingsWithOffsetPtr>(std::vector<PositionPostingsWithOffsetPtr>& docsets);

template std::enable_if_t<std::is_same_v<MockDocSetPtr, MockDocSetPtr>,
IntersectionPtr<MockDocSetPtr, MockDocSetPtr>>
Intersection<MockDocSetPtr, MockDocSetPtr>::create<MockDocSetPtr>(
std::vector<MockDocSetPtr>& docsets);

// docset_mut_specialized
template std::enable_if_t<
std::is_same_v<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>,
PositionPostingsWithOffsetPtr&>
Intersection<PositionPostingsWithOffsetPtr, PositionPostingsWithOffsetPtr>::docset_mut_specialized<
PositionPostingsWithOffsetPtr>(size_t ord);

template std::enable_if_t<std::is_same_v<MockDocSetPtr, MockDocSetPtr>, MockDocSetPtr&>
Intersection<MockDocSetPtr, MockDocSetPtr>::docset_mut_specialized<MockDocSetPtr>(size_t ord);
#define INSTANTIATE_INTERSECTION(T) \
template class Intersection<T, T>; \
template std::enable_if_t<std::is_same_v<T, T>, IntersectionPtr<T, T>> \
Intersection<T, T>::create<T>(std::vector<T> & docsets); \
template std::enable_if_t<std::is_same_v<T, T>, T&> \
Intersection<T, T>::docset_mut_specialized<T>(size_t ord);

INSTANTIATE_INTERSECTION(std::shared_ptr<PostingsWithOffset<PostingsPtr>>)
INSTANTIATE_INTERSECTION(std::shared_ptr<PostingsWithOffset<PositionPostingsPtr>>)
INSTANTIATE_INTERSECTION(MockDocSetPtr)

#undef INSTANTIATE_INTERSECTION

} // namespace doris::segment_v2::inverted_index::query_v2
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// be/src/olap/rowset/segment_v2/inverted_index/query_v2/nullable_scorer.h

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <glog/logging.h>

#include <memory>

#include "common/exception.h"
#include "olap/rowset/segment_v2/inverted_index/query_v2/null_bitmap_fetcher.h"
#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h"
#include "roaring/roaring.hh"

namespace doris::segment_v2::inverted_index::query_v2 {

template <typename ScorerPtrT = ScorerPtr>
class NullableScorer : public Scorer {
public:
NullableScorer(ScorerPtrT inner_scorer, std::shared_ptr<roaring::Roaring> null_bitmap)
: _inner_scorer(std::move(inner_scorer)), _null_bitmap(std::move(null_bitmap)) {}
~NullableScorer() override = default;

uint32_t advance() override { return _inner_scorer->advance(); }
uint32_t seek(uint32_t target) override { return _inner_scorer->seek(target); }
uint32_t doc() const override { return _inner_scorer->doc(); }
uint32_t size_hint() const override { return _inner_scorer->size_hint(); }
float score() override { return _inner_scorer->score(); }

bool has_null_bitmap(const NullBitmapResolver* /*resolver*/ = nullptr) override { return true; }

const roaring::Roaring* get_null_bitmap(
const NullBitmapResolver* /*resolver*/ = nullptr) override {
return _null_bitmap.get();
}

private:
ScorerPtrT _inner_scorer;
std::shared_ptr<roaring::Roaring> _null_bitmap;
};
using NullableScorerPtr = std::shared_ptr<NullableScorer<>>;

template <typename ScorerPtrT = ScorerPtr>
inline ScorerPtr make_nullable_scorer(ScorerPtrT inner_scorer, const std::string& logical_field,
const NullBitmapResolver* resolver) {
if (!inner_scorer) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"make_nullable_scorer: inner_scorer must not be null");
}

auto null_bitmap = FieldNullBitmapFetcher::fetch(resolver, logical_field, inner_scorer.get());

if (!null_bitmap || null_bitmap->isEmpty()) {
return inner_scorer;
}

return std::make_shared<NullableScorer<ScorerPtrT>>(std::move(inner_scorer),
std::move(null_bitmap));
}

} // namespace doris::segment_v2::inverted_index::query_v2
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "olap/rowset/segment_v2/index_query_context.h"
#include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h"
#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h"
#include "olap/rowset/segment_v2/inverted_index/similarity/bm25_similarity.h"

namespace doris::segment_v2::inverted_index::query_v2 {

class MultiPhraseQuery : public Query {
public:
MultiPhraseQuery(IndexQueryContextPtr context, std::wstring field,
std::vector<TermInfo> term_infos)
: _context(std::move(context)),
_field(std::move(field)),
_term_infos(std::move(term_infos)) {}
~MultiPhraseQuery() override = default;

WeightPtr weight(bool enable_scoring) override {
if (_term_infos.size() < 2) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"Multi-phrase query requires at least 2 terms, got {}",
_term_infos.size());
}

SimilarityPtr bm25_similarity;
if (enable_scoring) {
bm25_similarity = std::make_shared<BM25Similarity>();
std::vector<std::wstring> all_terms;
for (const auto& term_info : _term_infos) {
if (term_info.is_single_term()) {
all_terms.push_back(StringHelper::to_wstring(term_info.get_single_term()));
} else {
for (const auto& term : term_info.get_multi_terms()) {
all_terms.push_back(StringHelper::to_wstring(term));
}
}
}
bm25_similarity->for_terms(_context, _field, all_terms);
}

return std::make_shared<MultiPhraseWeight>(_context, _field, _term_infos, bm25_similarity,
enable_scoring, _nullable);
}

private:
IndexQueryContextPtr _context;

std::wstring _field;
std::vector<TermInfo> _term_infos;
bool _nullable = true;
};

} // namespace doris::segment_v2::inverted_index::query_v2
Loading
Loading