diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc index 09ccb841eafa..28d943c21499 100644 --- a/src/server/search/aggregator.cc +++ b/src/server/search/aggregator.cc @@ -65,9 +65,9 @@ void Aggregator::DoGroup(absl::Span fields, absl::Span false @@ -80,22 +80,41 @@ void Aggregator::DoSort(std::string_view field, bool descending) { desc -> false */ auto comparator = [&](const DocValues& l, const DocValues& r) { - auto l_it = l.find(field); - auto r_it = r.find(field); - - // If some of the values is not present - if (l_it == l.end() || r_it == r.end()) { - return l_it != l.end(); + for (const auto& [field, order] : sort_params.fields) { + auto l_it = l.find(field); + auto r_it = r.find(field); + + // If either value is missing: skip this field when both are missing, otherwise the document that has the value sorts first + if (l_it == l.end() || r_it == r.end()) { + if (l_it == l.end() && r_it == r.end()) { + continue; + } + return l_it != l.end(); + } + + const auto& lv = l_it->second; + const auto& rv = r_it->second; + if (lv == rv) { + continue; + } + return order == SortParams::SortOrder::ASC ? lv < rv : lv > rv; } - - auto& lv = l_it->second; - auto& rv = r_it->second; - return !descending ? 
lv < rv : lv > rv; + return false; }; - std::sort(result.values.begin(), result.values.end(), std::move(comparator)); + auto& values = result.values; + if (sort_params.SortAll()) { + std::sort(values.begin(), values.end(), comparator); + } else { + DCHECK_GE(sort_params.max, 0); + const size_t limit = std::min(values.size(), size_t(sort_params.max)); + std::partial_sort(values.begin(), values.begin() + limit, values.end(), comparator); + values.resize(limit); + } - result.fields_to_print.insert(field); + for (auto& field : sort_params.fields) { + result.fields_to_print.insert(field.first); + } } void Aggregator::DoLimit(size_t offset, size_t num) { @@ -152,10 +171,8 @@ AggregationStep MakeGroupStep(std::vector fields, std::vectorDoSort(field, descending); - }; +AggregationStep MakeSortStep(SortParams sort_params) { + return [params = std::move(sort_params)](Aggregator* aggregator) { aggregator->DoSort(params); }; } AggregationStep MakeLimitStep(size_t offset, size_t num) { diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h index a298735182f4..fe4dbbfe237a 100644 --- a/src/server/search/aggregator.h +++ b/src/server/search/aggregator.h @@ -33,9 +33,27 @@ struct AggregationResult { absl::flat_hash_set fields_to_print; }; +struct SortParams { + enum class SortOrder { ASC, DESC }; + + constexpr static int64_t kSortAll = -1; + + bool SortAll() const { + return max == kSortAll; + } + + /* Fields to sort by. If multiple fields are provided, sorting works hierarchically: + - First, the i-th field is compared. + - If the i-th field values are equal, the (i + 1)-th field is compared, and so on. */ + absl::InlinedVector, 2> fields; + /* Max number of elements to include in the sorted result. + If set, only the first [max] elements are fully sorted using partial_sort. 
*/ + int64_t max = kSortAll; +}; + struct Aggregator { void DoGroup(absl::Span fields, absl::Span reducers); - void DoSort(std::string_view field, bool descending = false); + void DoSort(const SortParams& sort_params); void DoLimit(size_t offset, size_t num); AggregationResult result; @@ -94,7 +112,7 @@ Reducer::Func FindReducerFunc(ReducerFunc name); AggregationStep MakeGroupStep(std::vector fields, std::vector reducers); // Make `SORTBY field [DESC]` step -AggregationStep MakeSortStep(std::string field, bool descending = false); +AggregationStep MakeSortStep(SortParams sort_params); // Make `LIMIT offset num` step AggregationStep MakeLimitStep(size_t offset, size_t num); diff --git a/src/server/search/aggregator_test.cc b/src/server/search/aggregator_test.cc index a9f9544ce3b7..a0adaffee309 100644 --- a/src/server/search/aggregator_test.cc +++ b/src/server/search/aggregator_test.cc @@ -18,7 +18,10 @@ TEST(AggregatorTest, Sort) { DocValues{{"a", 0.5}}, DocValues{{"a", 1.5}}, }; - StepsList steps = {MakeSortStep("a", false)}; + + SortParams params; + params.fields.emplace_back("a", SortParams::SortOrder::ASC); + StepsList steps = {MakeSortStep(std::move(params))}; auto result = Process(values, {"a"}, steps); diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 6547642310f4..f7c9c8de47a4 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -306,6 +306,42 @@ optional ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyB return params; } +std::optional ParseAggregatorSortParams(CmdArgParser* parser) { + using SordOrder = aggregate::SortParams::SortOrder; + + size_t strings_num = parser->Next(); + + aggregate::SortParams sort_params; + sort_params.fields.reserve(strings_num / 2); + + while (parser->HasNext() && strings_num > 0) { + // TODO: Throw an error if the field has no '@' sign at the beginning + std::string_view parsed_field = ParseFieldWithAtSign(parser); + strings_num--; + + 
SordOrder sord_order = SordOrder::ASC; + if (strings_num > 0) { + auto order = parser->TryMapNext("ASC", SordOrder::ASC, "DESC", SordOrder::DESC); + if (order) { + sord_order = order.value(); + strings_num--; + } + } + + sort_params.fields.emplace_back(parsed_field, sord_order); + } + + if (strings_num) { + return std::nullopt; + } + + if (parser->Check("MAX")) { + sort_params.max = parser->Next(); + } + + return sort_params; +} + optional ParseAggregatorParamsOrReply(CmdArgParser parser, SinkReplyBuilder* builder) { AggregateParams params; @@ -372,11 +408,13 @@ optional ParseAggregatorParamsOrReply(CmdArgParser parser, // SORTBY nargs if (parser.Check("SORTBY")) { - parser.ExpectTag("1"); - string_view field = parser.Next(); - bool desc = bool(parser.Check("DESC")); + auto sort_params = ParseAggregatorSortParams(&parser); + if (!sort_params) { + builder->SendError("bad arguments for SORTBY: specified invalid number of strings"); + return nullopt; + } - params.steps.push_back(aggregate::MakeSortStep(std::string{field}, desc)); + params.steps.push_back(aggregate::MakeSortStep(std::move(sort_params).value())); continue; } diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 0e1aebbd97d0..05244a3a3971 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -1680,4 +1680,149 @@ TEST_F(SearchFamilyTest, AggregateResultFields) { IsMap(), IsMap())); } +TEST_F(SearchFamilyTest, AggregateSortByJson) { + Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"}); + Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"}); + Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"}); + Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"}); + Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"}); + Run({"JSON.SET", "j6", "$", 
R"({"name": "sixth", "number": 300, "group": "first"})"}); + Run({"JSON.SET", "j7", "$", R"({"name": "seventh", "number": 400, "group": "second"})"}); + Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"}); + Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"}); + + Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number", + "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"}); + + // Test sorting by name (DESC) and number (ASC) + auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "DESC", "@number", "ASC"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("name", "\"third\"", "number", "300"), + IsMap("name", "\"sixth\"", "number", "300"), + IsMap("name", "\"seventh\"", "number", "400"), + IsMap("name", "\"second\"", "number", "800"), IsMap("name", "\"ninth\""), + IsMap("name", "\"fourth\"", "number", "400"), + IsMap("name", "\"first\"", "number", "1200"), + IsMap("name", "\"fifth\"", "number", "900"), IsMap("name", "\"eighth\""))); + + // Test sorting by name (ASC) and number (DESC) + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "4", "@name", "ASC", "@number", "DESC"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"), + IsMap("name", "\"first\"", "number", "1200"), + IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""), + IsMap("name", "\"second\"", "number", "800"), + IsMap("name", "\"seventh\"", "number", "400"), + IsMap("name", "\"sixth\"", "number", "300"), + IsMap("name", "\"third\"", "number", "300"))); + + // Test sorting by group (ASC), number (DESC), and name + resp = Run( + {"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@group", "ASC", "@number", "DESC", "@name"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("group", "\"first\"", "number", "1200", "name", "\"first\""), + IsMap("group", "\"first\"", "number", "800", "name", "\"second\""), + 
IsMap("group", "\"first\"", "number", "300", "name", "\"sixth\""), + IsMap("group", "\"first\"", "number", "300", "name", "\"third\""), + IsMap("group", "\"first\"", "name", "\"eighth\""), + IsMap("group", "\"second\"", "number", "900", "name", "\"fifth\""), + IsMap("group", "\"second\"", "number", "400", "name", "\"fourth\""), + IsMap("group", "\"second\"", "number", "400", "name", "\"seventh\""), + IsMap("group", "\"second\"", "name", "\"ninth\""))); + + // Test sorting by number (ASC), group (DESC), and name + resp = Run( + {"FT.AGGREGATE", "index", "*", "SORTBY", "5", "@number", "ASC", "@group", "DESC", "@name"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("number", "300", "group", "\"first\"", "name", "\"sixth\""), + IsMap("number", "300", "group", "\"first\"", "name", "\"third\""), + IsMap("number", "400", "group", "\"second\"", "name", "\"fourth\""), + IsMap("number", "400", "group", "\"second\"", "name", "\"seventh\""), + IsMap("number", "800", "group", "\"first\"", "name", "\"second\""), + IsMap("number", "900", "group", "\"second\"", "name", "\"fifth\""), + IsMap("number", "1200", "group", "\"first\"", "name", "\"first\""), + IsMap("group", "\"second\"", "name", "\"ninth\""), + IsMap("group", "\"first\"", "name", "\"eighth\""))); + + // Test sorting by number (ASC) with MAX 3 + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "3"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"), + IsMap("number", "400"))); + + // Test sorting by number (DESC) with MAX 3 + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "2", "@number", "DESC", "MAX", "3"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "1200"), IsMap("number", "900"), + IsMap("number", "800"))); + + // Test sorting by number (ASC) with MAX 999 + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@number", "MAX", "999"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "300"), IsMap("number", "300"), + IsMap("number", "400"), 
IsMap("number", "400"), + IsMap("number", "800"), IsMap("number", "900"), + IsMap("number", "1200"), IsMap(), IsMap())); + + // Test sorting by name and number (DESC) + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "3", "@name", "@number", "DESC"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("name", "\"eighth\""), IsMap("name", "\"fifth\"", "number", "900"), + IsMap("name", "\"first\"", "number", "1200"), + IsMap("name", "\"fourth\"", "number", "400"), IsMap("name", "\"ninth\""), + IsMap("name", "\"second\"", "number", "800"), + IsMap("name", "\"seventh\"", "number", "400"), + IsMap("name", "\"sixth\"", "number", "300"), + IsMap("name", "\"third\"", "number", "300"))); + + // Test SORTBY with MAX, GROUPBY, and REDUCE COUNT + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "1", + "@number", "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("number", "900", "count", "1"), + IsMap("number", ArgType(RespExpr::NIL), "count", "1"), + IsMap("number", "1200", "count", "1"))); + + // Test SORTBY with MAX, GROUPBY (0 fields), and REDUCE COUNT + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "3", "GROUPBY", "0", + "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("count", "3"))); +} + +TEST_F(SearchFamilyTest, AggregateSortByParsingErrors) { + Run({"JSON.SET", "j1", "$", R"({"name": "first", "number": 1200, "group": "first"})"}); + Run({"JSON.SET", "j2", "$", R"({"name": "second", "number": 800, "group": "first"})"}); + Run({"JSON.SET", "j3", "$", R"({"name": "third", "number": 300, "group": "first"})"}); + Run({"JSON.SET", "j4", "$", R"({"name": "fourth", "number": 400, "group": "second"})"}); + Run({"JSON.SET", "j5", "$", R"({"name": "fifth", "number": 900, "group": "second"})"}); + Run({"JSON.SET", "j6", "$", R"({"name": "sixth", "number": 300, "group": "first"})"}); + Run({"JSON.SET", "j7", "$", R"({"name": "seventh", 
"number": 400, "group": "second"})"}); + Run({"JSON.SET", "j8", "$", R"({"name": "eighth", "group": "first"})"}); + Run({"JSON.SET", "j9", "$", R"({"name": "ninth", "group": "second"})"}); + + Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.name", "AS", "name", "TEXT", "$.number", + "AS", "number", "NUMERIC", "$.group", "AS", "group", "TAG"}); + + // Test SORTBY with invalid argument count + auto resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "999", "@name", "@number", "DESC"}); + EXPECT_THAT(resp, ErrArg("bad arguments for SORTBY: specified invalid number of strings")); + + // Test SORTBY with negative argument count + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "-3", "@name", "@number", "DESC"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + // Test MAX with invalid value + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX", "-10"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); + + // Test MAX without a value + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@name", "MAX"}); + EXPECT_THAT(resp, ErrArg("syntax error")); + + // Test SORTBY with a non-existing field + /* Temporary unsupported + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "@nonexistingfield"}); + EXPECT_THAT(resp, ErrArg("Property `nonexistingfield` not loaded nor in schema")); */ + + // Test SORTBY with an invalid value + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "notvalue", "@name"}); + EXPECT_THAT(resp, ErrArg("value is not an integer or out of range")); +} + } // namespace dfly