Skip to content

Commit

Permalink
Add TopNRank optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
aditi-pandit committed Nov 15, 2024
1 parent f37dc00 commit c7926e5
Show file tree
Hide file tree
Showing 11 changed files with 991 additions and 239 deletions.
39 changes: 39 additions & 0 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1747,15 +1747,49 @@ PlanNodePtr RowNumberNode::create(const folly::dynamic& obj, void* context) {
source);
}

namespace {
std::unordered_map<TopNRowNumberNode::RankFunction, std::string>
rankFunctionNames() {
return {
{TopNRowNumberNode::RankFunction::kRowNumber, "row_number"},
{TopNRowNumberNode::RankFunction::kRank, "rank"},
{TopNRowNumberNode::RankFunction::kDenseRank, "dense_rank"},
};
}
} // namespace

// static
/// Returns the textual name ("row_number", "rank" or "dense_rank") of
/// 'function'. Fails a VELOX_CHECK if 'function' has no registered name.
///
/// The returned pointer refers to a string stored in a function-local static
/// map, so it stays valid after the call returns.
const char* TopNRowNumberNode::rankFunctionName(
    TopNRowNumberNode::RankFunction function) {
  // Built once, on the first call.
  static const auto kFunctionNames = rankFunctionNames();
  const auto it = kFunctionNames.find(function);
  // The enum value is printed as an integer because, by definition, it has no
  // name when this check fails.
  VELOX_CHECK(
      it != kFunctionNames.end(),
      "Invalid rank function {}",
      static_cast<int>(function));
  return it->second.c_str();
}

// static
/// Returns the RankFunction whose textual name equals 'name'. This is the
/// inverse of rankFunctionName(). Fails a VELOX_CHECK if 'name' is not one of
/// the registered rank function names.
TopNRowNumberNode::RankFunction TopNRowNumberNode::rankFunctionFromName(
    const std::string& name) {
  // Name -> enum lookup table, built once on the first call by inverting the
  // enum -> name table.
  static const auto kFunctionNames = invertMap(rankFunctionNames());
  const auto it = kFunctionNames.find(name);
  // Pass 'name' as a format argument, matching rankFunctionName(), rather
  // than splicing it into the message text.
  VELOX_CHECK(it != kFunctionNames.end(), "Invalid rank function {}", name);
  return it->second;
}

TopNRowNumberNode::TopNRowNumberNode(
PlanNodeId id,
RankFunction function,
std::vector<FieldAccessTypedExprPtr> partitionKeys,
std::vector<FieldAccessTypedExprPtr> sortingKeys,
std::vector<SortOrder> sortingOrders,
const std::optional<std::string>& rowNumberColumnName,
int32_t limit,
PlanNodePtr source)
: PlanNode(std::move(id)),
function_(function),
partitionKeys_{std::move(partitionKeys)},
sortingKeys_{std::move(sortingKeys)},
sortingOrders_{std::move(sortingOrders)},
Expand Down Expand Up @@ -1793,6 +1827,8 @@ TopNRowNumberNode::TopNRowNumberNode(
}

void TopNRowNumberNode::addDetails(std::stringstream& stream) const {
stream << rankFunctionName(function_) << " ";

if (!partitionKeys_.empty()) {
stream << "partition by (";
addFields(stream, partitionKeys_);
Expand All @@ -1808,6 +1844,7 @@ void TopNRowNumberNode::addDetails(std::stringstream& stream) const {

folly::dynamic TopNRowNumberNode::serialize() const {
auto obj = PlanNode::serialize();
obj["function"] = rankFunctionName(function_);
obj["partitionKeys"] = ISerializable::serialize(partitionKeys_);
obj["sortingKeys"] = ISerializable::serialize(sortingKeys_);
obj["sortingOrders"] = serializeSortingOrders(sortingOrders_);
Expand All @@ -1823,6 +1860,7 @@ PlanNodePtr TopNRowNumberNode::create(
const folly::dynamic& obj,
void* context) {
auto source = deserializeSingleSource(obj, context);
auto function = rankFunctionFromName(obj["function"].asString());
auto partitionKeys = deserializeFields(obj["partitionKeys"], context);
auto sortingKeys = deserializeFields(obj["sortingKeys"], context);

Expand All @@ -1835,6 +1873,7 @@ PlanNodePtr TopNRowNumberNode::create(

return std::make_shared<TopNRowNumberNode>(
deserializePlanNodeId(obj),
function,
partitionKeys,
sortingKeys,
sortingOrders,
Expand Down
58 changes: 53 additions & 5 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -2339,31 +2339,73 @@ class MarkDistinctNode : public PlanNode {
const RowTypePtr outputType_;
};

/// Optimized version of a WindowNode for a single row_number function with a
/// limit over sorted partitions.
/// The output of this node contains all input columns followed by an optional
/// Optimized version of a WindowNode for a single row_number, rank or
/// dense_rank function with a limit over sorted partitions. The output of this
/// node contains all input columns followed by an optional
/// 'rowNumberColumnName' BIGINT column.
class TopNRowNumberNode : public PlanNode {
public:
/// Ranking function computed by this node. Each value has a textual name
/// available through rankFunctionName().
enum class RankFunction {
// Named "row_number".
kRowNumber,
// Named "rank".
kRank,
// Named "dense_rank".
kDenseRank,
};

/// Returns the textual name of 'function': "row_number", "rank" or
/// "dense_rank". Fails if 'function' is not a known rank function.
static const char* rankFunctionName(TopNRowNumberNode::RankFunction function);

/// Returns the rank function whose textual name is 'name'. Inverse of
/// rankFunctionName(). Fails if 'name' is not a known rank function name.
static RankFunction rankFunctionFromName(const std::string& name);

/// @param function Rank function (row_number, rank, dense_rank) for TopN.
/// @param partitionKeys Partitioning keys. May be empty.
/// @param sortingKeys Sorting keys. May not be empty and may not intersect
/// with 'partitionKeys'.
/// @param sortingOrders Sorting orders, one per sorting key.
/// @param rowNumberColumnName Optional name of the column containing row
/// numbers. If not specified, the output doesn't include 'row number'
/// column. This is used when computing partial results.
/// numbers (or rank and dense_rank). If not specified, the output doesn't
/// include 'row number' column. This is used when computing partial results.
/// @param limit Per-partition limit. The number of
/// rows produced by this node will not exceed this value for any given
/// partition. Extra rows will be dropped.
TopNRowNumberNode(
PlanNodeId id,
RankFunction function,
std::vector<FieldAccessTypedExprPtr> partitionKeys,
std::vector<FieldAccessTypedExprPtr> sortingKeys,
std::vector<SortOrder> sortingOrders,
const std::optional<std::string>& rowNumberColumnName,
int32_t limit,
PlanNodePtr source);

#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY
/// Backward-compatible overload without a rank function argument. It always
/// uses RankFunction::kRowNumber and otherwise delegates to the primary
/// constructor.
///
/// @param partitionKeys Partitioning keys. May be empty.
/// @param sortingKeys Sorting keys. May not be empty and may not intersect
/// with 'partitionKeys'.
/// @param sortingOrders Sorting orders, one per sorting key.
/// @param rowNumberColumnName Optional name of the column containing row
/// numbers. If not specified, the output doesn't include 'row number'
/// column. This is used when computing partial results.
/// @param limit Per-partition limit. The number of
/// rows produced by this node will not exceed this value for any given
/// partition. Extra rows will be dropped.
TopNRowNumberNode(
    PlanNodeId id,
    std::vector<FieldAccessTypedExprPtr> partitionKeys,
    std::vector<FieldAccessTypedExprPtr> sortingKeys,
    std::vector<SortOrder> sortingOrders,
    const std::optional<std::string>& rowNumberColumnName,
    int32_t limit,
    PlanNodePtr source)
    // The by-value parameters are sinks: move them into the delegated
    // constructor so they are not copied a second time.
    : TopNRowNumberNode(
          std::move(id),
          RankFunction::kRowNumber,
          std::move(partitionKeys),
          std::move(sortingKeys),
          std::move(sortingOrders),
          rowNumberColumnName,
          limit,
          std::move(source)) {}
#endif

/// Returns the source plan nodes held by this node.
const std::vector<PlanNodePtr>& sources() const override {
return sources_;
}
Expand Down Expand Up @@ -2396,6 +2438,10 @@ class TopNRowNumberNode : public PlanNode {
return limit_;
}

/// Returns the rank function this node was constructed with.
RankFunction rankFunction() const {
return function_;
}

/// Returns true if the output type has more columns than the output of the
/// first source, i.e. a column was appended after the input columns.
bool generateRowNumber() const {
  const auto numInputColumns = sources_[0]->outputType()->size();
  const auto numOutputColumns = outputType_->size();
  return numOutputColumns > numInputColumns;
}
Expand All @@ -2411,6 +2457,8 @@ class TopNRowNumberNode : public PlanNode {
private:
void addDetails(std::stringstream& stream) const override;

const RankFunction function_;

const std::vector<FieldAccessTypedExprPtr> partitionKeys_;

const std::vector<FieldAccessTypedExprPtr> sortingKeys_;
Expand Down
33 changes: 33 additions & 0 deletions velox/exec/RowContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,22 @@ bool RowComparator::operator()(const char* lhs, const char* rhs) {
return false;
}

/// Three-way comparison of two rows over the configured sort keys.
///
/// Keys are compared in the order they appear in 'keyInfo_'; the result of
/// the first key that differs is returned. Returns 0 when every key compares
/// equal, or when 'lhs' and 'rhs' are the same row.
int32_t RowComparator::compare(const char* lhs, const char* rhs) {
  if (lhs == rhs) {
    // Same row pointer: equal by definition, no need to inspect the keys.
    return 0;
  }
  for (auto& key : keyInfo_) {
    if (auto result = rowContainer_->compare(
            lhs,
            rhs,
            key.first,
            {key.second.isNullsFirst(), key.second.isAscending(), false})) {
      return result;
    }
  }
  return 0;
}

bool RowComparator::operator()(
const std::vector<DecodedVector>& decodedVectors,
vector_size_t index,
Expand All @@ -1279,4 +1295,21 @@ bool RowComparator::operator()(
}
return false;
}

/// Three-way comparison between the value at 'index' of 'decodedVectors' and
/// the stored row 'rhs', over the configured sort keys.
///
/// Keys are visited in 'keyInfo_' order and the first non-zero per-key result
/// is returned; 0 means all keys compared equal.
///
/// NOTE(review): 'rhs' is passed as the first (row) argument of
/// RowContainer::compare, so confirm the sign of the result matches the
/// "< 0 for 'decodedVectors[index]' < 'rhs'" contract in the header.
int32_t RowComparator::compare(
    const std::vector<DecodedVector>& decodedVectors,
    vector_size_t index,
    const char* rhs) {
  for (auto& [column, order] : keyInfo_) {
    const auto keyResult = rowContainer_->compare(
        rhs,
        rowContainer_->columnAt(column),
        decodedVectors[column],
        index,
        {order.isNullsFirst(), order.isAscending(), false});
    if (keyResult != 0) {
      return keyResult;
    }
  }
  return 0;
}
} // namespace facebook::velox::exec
10 changes: 10 additions & 0 deletions velox/exec/RowContainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1762,12 +1762,22 @@ class RowComparator {
/// Returns true if lhs < rhs, false otherwise.
bool operator()(const char* lhs, const char* rhs);

/// Returns 0 for equal, < 0 for 'lhs' < 'rhs', > 0 otherwise.
int32_t compare(const char* lhs, const char* rhs);

/// Returns true if decodedVectors[index] < rhs, false otherwise.
bool operator()(
const std::vector<DecodedVector>& decodedVectors,
vector_size_t index,
const char* rhs);

/// Returns 0 for equal, < 0 for 'decodedVectors[index]' < 'rhs',
/// > 0 otherwise.
int32_t compare(
const std::vector<DecodedVector>& decodedVectors,
vector_size_t index,
const char* rhs);

private:
std::vector<std::pair<column_index_t, core::SortOrder>> keyInfo_;
RowContainer* rowContainer_;
Expand Down
Loading

0 comments on commit c7926e5

Please sign in to comment.