Skip to content

Commit

Permalink
fix(search_family): Support multiple fields in SORTBY option in the F…
Browse files Browse the repository at this point in the history
…T.AGGREGATE command. SECOND PR (#4232)

fix(search_family): Support multiple fields in SORTBY option in the FT.AGGREGATE command

fixes dragonfly#3631

Signed-off-by: Stepan Bagritsevich <[email protected]>
  • Loading branch information
BagritsevichStepan authored Dec 24, 2024
1 parent 3c7e312 commit aeeb625
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 25 deletions.
53 changes: 35 additions & 18 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ void Aggregator::DoGroup(absl::Span<const std::string> fields, absl::Span<const
}
}

void Aggregator::DoSort(std::string_view field, bool descending) {
void Aggregator::DoSort(const SortParams& sort_params) {
/*
Comparator for sorting DocValues by field.
Comparator for sorting DocValues by fields.
If any of the fields is not present in the DocValues, the comparator returns:
1. l_it == l.end() && r_it != r.end()
asc -> false
Expand All @@ -80,22 +80,41 @@ void Aggregator::DoSort(std::string_view field, bool descending) {
desc -> false
*/
auto comparator = [&](const DocValues& l, const DocValues& r) {
auto l_it = l.find(field);
auto r_it = r.find(field);

// If some of the values is not present
if (l_it == l.end() || r_it == r.end()) {
return l_it != l.end();
for (const auto& [field, order] : sort_params.fields) {
auto l_it = l.find(field);
auto r_it = r.find(field);

// If some of the values is not present
if (l_it == l.end() || r_it == r.end()) {
if (l_it == l.end() && r_it == r.end()) {
continue;
}
return l_it != l.end();
}

const auto& lv = l_it->second;
const auto& rv = r_it->second;
if (lv == rv) {
continue;
}
return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv;
}

auto& lv = l_it->second;
auto& rv = r_it->second;
return !descending ? lv < rv : lv > rv;
return false;
};

std::sort(result.values.begin(), result.values.end(), std::move(comparator));
auto& values = result.values;
if (sort_params.SortAll()) {
std::sort(values.begin(), values.end(), comparator);
} else {
DCHECK_GE(sort_params.max, 0);
const size_t limit = std::min(values.size(), size_t(sort_params.max));
std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator);
values.resize(limit);
}

result.fields_to_print.insert(field);
for (auto& field : sort_params.fields) {
result.fields_to_print.insert(field.first);
}
}

void Aggregator::DoLimit(size_t offset, size_t num) {
Expand Down Expand Up @@ -152,10 +171,8 @@ AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reduc
};
}

AggregationStep MakeSortStep(std::string field, bool descending) {
return [field = std::move(field), descending](Aggregator* aggregator) {
aggregator->DoSort(field, descending);
};
AggregationStep MakeSortStep(SortParams sort_params) {
return [params = std::move(sort_params)](Aggregator* aggregator) { aggregator->DoSort(params); };
}

AggregationStep MakeLimitStep(size_t offset, size_t num) {
Expand Down
22 changes: 20 additions & 2 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,27 @@ struct AggregationResult {
absl::flat_hash_set<std::string_view> fields_to_print;
};

// Parameters of a single SORTBY aggregation step.
struct SortParams {
  enum class SortOrder { ASC, DESC };

  // Sentinel stored in `max` when no MAX limit was requested.
  static constexpr int64_t kSortAll = -1;

  // Returns true when `max` holds the kSortAll sentinel, i.e. the whole
  // result set has to be sorted.
  bool SortAll() const {
    return kSortAll == max;
  }

  // Sort keys in priority order, each paired with its direction.
  // Sorting is hierarchical: the i-th key is compared first and the
  // (i + 1)-th key is consulted only when the i-th values are equal.
  absl::InlinedVector<std::pair<std::string, SortOrder>, 2> fields;

  // Maximum number of elements to include in the sorted result.
  // When set (anything other than kSortAll), only the first `max` elements
  // are fully sorted, using partial_sort.
  int64_t max = kSortAll;
};

struct Aggregator {
void DoGroup(absl::Span<const std::string> fields, absl::Span<const Reducer> reducers);
void DoSort(std::string_view field, bool descending = false);
void DoSort(const SortParams& sort_params);
void DoLimit(size_t offset, size_t num);

AggregationResult result;
Expand Down Expand Up @@ -94,7 +112,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name);
AggregationStep MakeGroupStep(std::vector<std::string> fields, std::vector<Reducer> reducers);

// Make `SORTBY nargs field [ASC|DESC] ... [MAX num]` step
AggregationStep MakeSortStep(std::string field, bool descending = false);
AggregationStep MakeSortStep(SortParams sort_params);

// Make `LIMIT offset num` step
AggregationStep MakeLimitStep(size_t offset, size_t num);
Expand Down
5 changes: 4 additions & 1 deletion src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ TEST(AggregatorTest, Sort) {
DocValues{{"a", 0.5}},
DocValues{{"a", 1.5}},
};
StepsList steps = {MakeSortStep("a", false)};

SortParams params;
params.fields.emplace_back("a", SortParams::SortOrder::ASC);
StepsList steps = {MakeSortStep(std::move(params))};

auto result = Process(values, {"a"}, steps);

Expand Down
46 changes: 42 additions & 4 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,42 @@ optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB
return params;
}

// Parses the arguments that follow the SORTBY keyword of FT.AGGREGATE:
//
//   SORTBY nargs <field [ASC|DESC]> ... [MAX num]
//
// `nargs` is the number of tokens (field names plus optional ASC/DESC
// modifiers) that belong to the SORTBY clause. A field without an explicit
// modifier is sorted in ascending order.
//
// Returns std::nullopt when the arguments run out before `nargs` tokens were
// consumed; the caller turns that into a SORTBY-specific error reply. Numeric
// parse failures (nargs / MAX) are not reported here.
// NOTE(review): they are presumably recorded inside the parser and reported by
// the caller - confirm against CmdArgParser.
std::optional<aggregate::SortParams> ParseAggregatorSortParams(CmdArgParser* parser) {
  using SortOrder = aggregate::SortParams::SortOrder;

  // `nargs` comes straight from the client, so it must not drive an unbounded
  // allocation. The reservation is only a hint; the vector still grows on
  // demand if more fields are actually supplied.
  constexpr size_t kMaxReservedFields = 16;
  // Largest MAX value that fits into SortParams::max (== INT64_MAX).
  constexpr uint64_t kMaxLimit = ~uint64_t{0} >> 1;

  size_t strings_num = parser->Next<size_t>();

  aggregate::SortParams sort_params;
  const size_t expected_fields = strings_num / 2;
  sort_params.fields.reserve(expected_fields < kMaxReservedFields ? expected_fields
                                                                  : kMaxReservedFields);

  while (parser->HasNext() && strings_num > 0) {
    // TODO: Throw an error if the field has no '@' sign at the beginning
    std::string_view parsed_field = ParseFieldWithAtSign(parser);
    strings_num--;

    SortOrder sort_order = SortOrder::ASC;
    // An ASC/DESC token is consumed only while it is still covered by nargs.
    if (strings_num > 0) {
      auto order = parser->TryMapNext("ASC", SortOrder::ASC, "DESC", SortOrder::DESC);
      if (order) {
        sort_order = order.value();
        strings_num--;
      }
    }

    sort_params.fields.emplace_back(parsed_field, sort_order);
  }

  // Fewer tokens were available than nargs announced.
  if (strings_num) {
    return std::nullopt;
  }

  if (parser->Check("MAX")) {
    // MAX is parsed as unsigned so that negative input keeps being rejected,
    // then clamped so the conversion to the signed `max` field can neither wrap
    // to a negative value nor collide with the kSortAll sentinel. A limit this
    // large always exceeds the number of rows, so clamping does not change the
    // result.
    const uint64_t requested = parser->Next<size_t>();
    sort_params.max = static_cast<int64_t>(requested < kMaxLimit ? requested : kMaxLimit);
  }

  return sort_params;
}

optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,
SinkReplyBuilder* builder) {
AggregateParams params;
Expand Down Expand Up @@ -372,11 +408,13 @@ optional<AggregateParams> ParseAggregatorParamsOrReply(CmdArgParser parser,

// SORTBY nargs
if (parser.Check("SORTBY")) {
parser.ExpectTag("1");
string_view field = parser.Next();
bool desc = bool(parser.Check("DESC"));
auto sort_params = ParseAggregatorSortParams(&parser);
if (!sort_params) {
builder->SendError("bad arguments for SORTBY: specified invalid number of strings");
return nullopt;
}

params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc));
params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value()));
continue;
}

Expand Down
145 changes: 145 additions & 0 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1680,4 +1680,149 @@ TEST_F(SearchFamilyTest, AggregateResultFields) {
IsMap(), IsMap()));
}

// Exercises FT.AGGREGATE ... SORTBY over JSON documents: multiple sort fields,
// mixed ASC/DESC directions, the MAX option and SORTBY followed by GROUPBY.
// NOTE(review): every expectation uses IsUnordArrayWithSize, whose name
// suggests element order is not checked - confirm whether these assertions
// actually verify the sort order.
TEST_F(SearchFamilyTest, AggregateSortByJson) {
  // Fixture documents as {key, json} pairs. j8 and j9 have no "number" property.
  constexpr const char* kDocs[][2] = {
      {"j1", R"({"name": "first", "number": 1200, "group": "first"})"},
      {"j2", R"({"name": "second", "number": 800, "group": "first"})"},
      {"j3", R"({"name": "third", "number": 300, "group": "first"})"},
      {"j4", R"({"name": "fourth", "number": 400, "group": "second"})"},
      {"j5", R"({"name": "fifth", "number": 900, "group": "second"})"},
      {"j6", R"({"name": "sixth", "number": 300, "group": "first"})"},
      {"j7", R"({"name": "seventh", "number": 400, "group": "second"})"},
      {"j8", R"({"name": "eighth", "group": "first"})"},
      {"j9", R"({"name": "ninth", "group": "second"})"},
  };
  for (const auto& doc : kDocs) {
    Run({"JSON.SET", doc[0], "$", doc[1]});
  }

  Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
       "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

  // name DESC, then number ASC
  EXPECT_THAT(
      Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "DESC", "@number", "ASC"}),
      IsUnordArrayWithSize(IsMap("name", "\"third\"", "number", "300"),
                           IsMap("name", "\"sixth\"", "number", "300"),
                           IsMap("name", "\"seventh\"", "number", "400"),
                           IsMap("name", "\"second\"", "number", "800"),
                           IsMap("name", "\"ninth\""),
                           IsMap("name", "\"fourth\"", "number", "400"),
                           IsMap("name", "\"first\"", "number", "1200"),
                           IsMap("name", "\"fifth\"", "number", "900"),
                           IsMap("name", "\"eighth\"")));

  // name ASC, then number DESC
  EXPECT_THAT(
      Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "ASC", "@number", "DESC"}),
      IsUnordArrayWithSize(IsMap("name", "\"eighth\""),
                           IsMap("name", "\"fifth\"", "number", "900"),
                           IsMap("name", "\"first\"", "number", "1200"),
                           IsMap("name", "\"fourth\"", "number", "400"),
                           IsMap("name", "\"ninth\""),
                           IsMap("name", "\"second\"", "number", "800"),
                           IsMap("name", "\"seventh\"", "number", "400"),
                           IsMap("name", "\"sixth\"", "number", "300"),
                           IsMap("name", "\"third\"", "number", "300")));

  // group ASC, then number DESC, then name (no explicit direction)
  EXPECT_THAT(
      Run({"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@group", "ASC", "@number", "DESC",
           "@name"}),
      IsUnordArrayWithSize(IsMap("group", "\"first\"", "number", "1200", "name", "\"first\""),
                           IsMap("group", "\"first\"", "number", "800", "name", "\"second\""),
                           IsMap("group", "\"first\"", "number", "300", "name", "\"sixth\""),
                           IsMap("group", "\"first\"", "number", "300", "name", "\"third\""),
                           IsMap("group", "\"first\"", "name", "\"eighth\""),
                           IsMap("group", "\"second\"", "number", "900", "name", "\"fifth\""),
                           IsMap("group", "\"second\"", "number", "400", "name", "\"fourth\""),
                           IsMap("group", "\"second\"", "number", "400", "name", "\"seventh\""),
                           IsMap("group", "\"second\"", "name", "\"ninth\"")));

  // number ASC, then group DESC, then name (no explicit direction)
  EXPECT_THAT(
      Run({"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@number", "ASC", "@group", "DESC",
           "@name"}),
      IsUnordArrayWithSize(IsMap("number", "300", "group", "\"first\"", "name", "\"sixth\""),
                           IsMap("number", "300", "group", "\"first\"", "name", "\"third\""),
                           IsMap("number", "400", "group", "\"second\"", "name", "\"fourth\""),
                           IsMap("number", "400", "group", "\"second\"", "name", "\"seventh\""),
                           IsMap("number", "800", "group", "\"first\"", "name", "\"second\""),
                           IsMap("number", "900", "group", "\"second\"", "name", "\"fifth\""),
                           IsMap("number", "1200", "group", "\"first\"", "name", "\"first\""),
                           IsMap("group", "\"second\"", "name", "\"ninth\""),
                           IsMap("group", "\"first\"", "name", "\"eighth\"")));

  // number (default direction) with MAX 3
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "3"}),
              IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
                                   IsMap("number", "400")));

  // number DESC with MAX 3
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "2", "@number", "DESC", "MAX", "3"}),
              IsUnordArrayWithSize(IsMap("number", "1200"), IsMap("number", "900"),
                                   IsMap("number", "800")));

  // number (default direction) with MAX 999, i.e. a limit above the document count
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "999"}),
              IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"),
                                   IsMap("number", "400"), IsMap("number", "400"),
                                   IsMap("number", "800"), IsMap("number", "900"),
                                   IsMap("number", "1200"), IsMap(), IsMap()));

  // name (no explicit direction), then number DESC
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "3", "@name", "@number", "DESC"}),
              IsUnordArrayWithSize(IsMap("name", "\"eighth\""),
                                   IsMap("name", "\"fifth\"", "number", "900"),
                                   IsMap("name", "\"first\"", "number", "1200"),
                                   IsMap("name", "\"fourth\"", "number", "400"),
                                   IsMap("name", "\"ninth\""),
                                   IsMap("name", "\"second\"", "number", "800"),
                                   IsMap("name", "\"seventh\"", "number", "400"),
                                   IsMap("name", "\"sixth\"", "number", "300"),
                                   IsMap("name", "\"third\"", "number", "300")));

  // SORTBY with MAX, followed by GROUPBY on one field and REDUCE COUNT
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY",
                   "1", "@number", "REDUCE", "COUNT", "0", "AS", "count"}),
              IsUnordArrayWithSize(IsMap("number", "900", "count", "1"),
                                   IsMap("number", ArgType(RespExpr::NIL), "count", "1"),
                                   IsMap("number", "1200", "count", "1")));

  // SORTBY with MAX, followed by GROUPBY on zero fields and REDUCE COUNT
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY",
                   "0", "REDUCE", "COUNT", "0", "AS", "count"}),
              IsUnordArrayWithSize(IsMap("count", "3")));
}

// Checks the error replies produced for malformed SORTBY clauses of FT.AGGREGATE.
TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) {
  // Fixture documents as {key, json} pairs.
  constexpr const char* kDocs[][2] = {
      {"j1", R"({"name": "first", "number": 1200, "group": "first"})"},
      {"j2", R"({"name": "second", "number": 800, "group": "first"})"},
      {"j3", R"({"name": "third", "number": 300, "group": "first"})"},
      {"j4", R"({"name": "fourth", "number": 400, "group": "second"})"},
      {"j5", R"({"name": "fifth", "number": 900, "group": "second"})"},
      {"j6", R"({"name": "sixth", "number": 300, "group": "first"})"},
      {"j7", R"({"name": "seventh", "number": 400, "group": "second"})"},
      {"j8", R"({"name": "eighth", "group": "first"})"},
      {"j9", R"({"name": "ninth", "group": "second"})"},
  };
  for (const auto& doc : kDocs) {
    Run({"JSON.SET", doc[0], "$", doc[1]});
  }

  Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number",
       "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"});

  // Argument count larger than the number of tokens actually supplied
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "999", "@name", "@number", "DESC"}),
              ErrArg("bad arguments for SORTBY: specified invalid number of strings"));

  // Negative argument count
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"}),
              ErrArg("value is not an integer or out of range"));

  // Negative MAX value
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"}),
              ErrArg("value is not an integer or out of range"));

  // MAX without a value
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"}),
              ErrArg("syntax error"));

  // SORTBY with a non-existing field
  /* Temporarily unsupported
  resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@nonexistingfield"});
  EXPECT_THAT(resp, ErrArg("Property `nonexistingfield` not loaded nor in schema")); */

  // Non-numeric argument count
  EXPECT_THAT(Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"}),
              ErrArg("value is not an integer or out of range"));
}

} // namespace dfly

0 comments on commit aeeb625

Please sign in to comment.