Skip to content

Commit

Permalink
refactor(search_family): Add Aggregator class (#4290)
Browse files Browse the repository at this point in the history
* refactor(search_family): Add Aggregator class

Signed-off-by: Stepan Bagritsevich <[email protected]>

* fix(aggregator_test): Fix failing tests

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor: Restore the previous comment

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor: address comments 2

Signed-off-by: Stepan Bagritsevich <[email protected]>

* refactor: address comments 3

Signed-off-by: Stepan Bagritsevich <[email protected]>

* fix(aggregator): Simplify comparator for the case when one of the values is not present

Signed-off-by: Stepan Bagritsevich <[email protected]>

---------

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored Dec 23, 2024
1 parent 8d66c25 commit 1fa9a47
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 104 deletions.
171 changes: 95 additions & 76 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,99 @@ namespace dfly::aggregate {

namespace {

struct GroupStep {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}
using ValuesList = absl::FixedArray<Value>;

// Restore DocValues and apply reducers
std::vector<DocValues> out;
while (!groups.empty()) {
auto node = groups.extract(groups.begin());
DocValues doc = Unpack(std::move(node.key()));
for (auto& reducer : reducers_) {
doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
}
out.push_back(std::move(doc));
}
// Collects dv[field] for every field in `fields`, positionally; fields the
// document does not contain yield a default-constructed Value{}.
ValuesList ExtractFieldsValues(const DocValues& dv, absl::Span<const std::string> fields) {
  ValuesList extracted(fields.size());
  size_t idx = 0;
  for (const auto& field : fields) {
    auto found = dv.find(field);
    extracted[idx++] = (found == dv.end()) ? Value{} : found->second;
  }
  return extracted;
}

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());
// Inverse of ExtractFieldsValues: rebuilds a DocValues map by pairing each
// value with the field name at the same position.
DocValues PackFields(ValuesList values, absl::Span<const std::string> fields) {
  DCHECK_EQ(values.size(), fields.size());
  DocValues packed;
  size_t idx = 0;
  for (const auto& field : fields) {
    packed[field] = std::move(values[idx++]);
  }
  return packed;
}

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}
// Shared default value; presumably returned when a looked-up field is absent
// (e.g. by ValueIterator::operator*) — TODO(review): confirm, usage is in a
// collapsed region of this diff.
const Value kEmptyValue = Value{};

} // namespace

return {std::move(out), std::move(fields_to_print)};
// Implements `GROUPBY [fields...] REDUCE ...`: groups result.values by the
// given fields, applies each reducer to every group, and replaces
// result.values with one DocValues per group. result.fields_to_print is reset
// to the group fields plus the reducer output fields.
// NOTE: the diff rendering had stale lines from the removed Extract/Unpack
// helpers spliced into this body; they are dropped here.
void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers) {
  // Separate items into groups keyed by the extracted field values.
  absl::flat_hash_map<ValuesList, std::vector<DocValues>> groups;
  for (auto& value : result.values) {
    groups[ExtractFieldsValues(value, fields)].push_back(std::move(value));
  }

  // Restore DocValues and apply reducers.
  auto& values = result.values;
  values.clear();
  values.reserve(groups.size());
  while (!groups.empty()) {
    // extract() yields the map node so the key (ValuesList) can be moved out.
    auto node = groups.extract(groups.begin());
    DocValues doc = PackFields(std::move(node.key()), fields);
    for (auto& reducer : reducers) {
      doc[reducer.result_field] = reducer.func({reducer.source_field, node.mapped()});
    }
    values.push_back(std::move(doc));
  }

  // After grouping, only the grouped fields and reducer outputs are printable.
  auto& fields_to_print = result.fields_to_print;
  fields_to_print.clear();
  fields_to_print.reserve(fields.size() + reducers.size());

  for (auto& field : fields) {
    fields_to_print.insert(field);
  }
  for (auto& reducer : reducers) {
    fields_to_print.insert(reducer.result_field);
  }
}

std::vector<std::string> fields_;
std::vector<Reducer> reducers_;
};
void Aggregator::DoSort(std::string_view field, bool descending) {
/*
Comparator for sorting DocValues by field.
If some of the fields is not present in the DocValues, comparator returns:
1. l_it == l.end() && r_it != r.end()
asc -> false
desc -> false
2. l_it != l.end() && r_it == r.end()
asc -> true
desc -> true
3. l_it == l.end() && r_it == r.end()
asc -> false
desc -> false
*/
auto comparator = [&](const DocValues& l, const DocValues& r) {
auto l_it = l.find(field);
auto r_it = r.find(field);

// If some of the values is not present
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end();
}

const Value kEmptyValue = Value{};
auto& lv = l_it->second;
auto& rv = r_it->second;
return !descending ? lv < rv : lv > rv;
};

} // namespace
std::sort(result.values.begin(), result.values.end(), std::move(comparator));

result.fields_to_print.insert(field);
}

// Implements `LIMIT offset num`: drops the first `offset` documents, then
// keeps at most `num` of the remainder; both bounds are clamped to the size.
void Aggregator::DoLimit(size_t offset, size_t num) {
  auto& docs = result.values;
  const size_t skip = std::min(offset, docs.size());
  docs.erase(docs.begin(), docs.begin() + skip);
  if (docs.size() > num) {
    docs.resize(num);
  }
}

const Value& ValueIterator::operator*() const {
auto it = values_.front().find(field_);
Expand Down Expand Up @@ -109,48 +146,30 @@ Reducer::Func FindReducerFunc(ReducerFunc name) {
return nullptr;
}

PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers) {
return GroupStep{std::vector<std::string>(fields.begin(), fields.end()), std::move(reducers)};
// Builds a `GROUPBY [fields...]` step with its REDUCE operations; the returned
// closure owns the field names and reducers for the lifetime of the pipeline.
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers) {
  auto step = [group_fields = std::move(fields),
               group_reducers = std::move(reducers)](Aggregator* agg) {
    agg->DoGroup(group_fields, group_reducers);
  };
  return step;
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});

if (descending) {
std::reverse(values.begin(), values.end());
}

result.fields_to_print.insert(field);
return result;
// Builds a `SORTBY field [DESC]` step; the closure owns the field name so the
// string_view handed to DoSort stays valid for the pipeline's lifetime.
AggregationStep MakeSortStep(std::string field, bool descending) {
  auto step = [sort_field = std::move(field), descending](Aggregator* agg) {
    agg->DoSort(sort_field, descending);
  };
  return step;
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return result;
};
// Builds a `LIMIT offset num` step.
AggregationStep MakeLimitStep(size_t offset, size_t num) {
  auto step = [offset, num](Aggregator* agg) { agg->DoLimit(offset, num); };
  return step;
}

PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
// Runs the aggregation steps in order over an Aggregator seeded with `values`
// and the initial set of printable fields, and returns the final state.
// NOTE: the diff rendering had stale pipeline lines (`PipelineResult
// step_result = ...`, `return result;`) spliced in; they are dropped here.
AggregationResult Process(std::vector<DocValues> values,
                          absl::Span<const std::string_view> fields_to_print,
                          absl::Span<const AggregationStep> steps) {
  Aggregator aggregator{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
  for (auto& step : steps) {
    // Steps mutate the aggregator in place; no per-step result copies.
    step(&aggregator);
  }
  return aggregator.result;
}

} // namespace dfly::aggregate
35 changes: 23 additions & 12 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,31 @@

namespace dfly::aggregate {

struct Reducer;

using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

struct PipelineResult {
// DocValues sent through the pipeline
// TODO: Replace DocValues with compact linear search map instead of hash map
using DocValues = absl::flat_hash_map<std::string_view, Value>;

// Mutable state threaded through the aggregation steps and returned by Process().
// NOTE(review): both members hold string_view keys/entries — the backing
// strings presumably live in the step closures and input documents; confirm
// lifetimes against callers.
struct AggregationResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;

// Fields from values to be printed
absl::flat_hash_set<std::string_view> fields_to_print;
};

// Executes aggregation steps in place: each Do* method transforms
// `result.values` and updates `result.fields_to_print`.
struct Aggregator {
// GROUPBY: groups values by `fields`, then applies each reducer per group.
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
// SORTBY: sorts values by `field`; documents lacking the field order last.
void DoSort(std::string_view field, bool descending = false);
// LIMIT: skips `offset` values, then truncates to at most `num`.
void DoLimit(size_t offset, size_t num);

AggregationResult result;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.
using AggregationStep = std::function<void(Aggregator*)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -79,18 +91,17 @@ enum class ReducerFunc { COUNT, COUNT_DISTINCT, SUM, AVG, MAX, MIN };
Reducer::Func FindReducerFunc(ReducerFunc name);

// Make `GROUPBY [fields...]` with REDUCE step
PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
std::vector<Reducer> reducers);
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);

// Make `SORTBY field [DESC]` step
PipelineStep MakeSortStep(std::string_view field, bool descending = false);
AggregationStep MakeSortStep(std::string field, bool descending = false);

// Make `LIMIT offset num` step
PipelineStep MakeLimitStep(size_t offset, size_t num);
AggregationStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);
AggregationResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const AggregationStep> steps);

} // namespace dfly::aggregate
16 changes: 10 additions & 6 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@ namespace dfly::aggregate {

using namespace std::string_literals;

using StepsList = std::vector<AggregationStep>;

TEST(AggregatorTest, Sort) {
std::vector<DocValues> values = {
DocValues{{"a", 1.0}},
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
PipelineStep steps[] = {MakeSortStep("a", false)};
StepsList steps = {MakeSortStep("a", false)};

auto result = Process(values, {"a"}, steps);

Expand All @@ -32,7 +34,8 @@ TEST(AggregatorTest, Limit) {
DocValues{{"i", 3.0}},
DocValues{{"i", 4.0}},
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

StepsList steps = {MakeLimitStep(1, 2)};

auto result = Process(values, {"i"}, steps);

Expand All @@ -49,8 +52,8 @@ TEST(AggregatorTest, SimpleGroup) {
DocValues{{"i", 4.0}, {"tag", "even"}},
};

std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};
std::vector<std::string> fields = {"tag"};
StepsList steps = {MakeGroupStep(std::move(fields), {})};

auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
Expand All @@ -72,13 +75,14 @@ TEST(AggregatorTest, GroupWithReduce) {
});
}

std::string_view fields[] = {"tag"};
std::vector<std::string> fields = {"tag"};
std::vector<Reducer> reducers = {
Reducer{"", "count", FindReducerFunc(ReducerFunc::COUNT)},
Reducer{"i", "sum-i", FindReducerFunc(ReducerFunc::SUM)},
Reducer{"half-i", "distinct-hi", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)},
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

StepsList steps = {MakeGroupStep(std::move(fields), std::move(reducers))};

auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);
Expand Down
2 changes: 1 addition & 1 deletion src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ struct AggregateParams {
search::QueryParams params;

std::optional<SearchFieldsList> load_fields;
std::vector<aggregate::PipelineStep> steps;
std::vector<aggregate::AggregationStep> steps;
};

// Stores basic info about a document index.
Expand Down
Loading

0 comments on commit 1fa9a47

Please sign in to comment.