Skip to content

Commit

Permalink
Add ntile() spark window function (facebookincubator#8597)
Browse files Browse the repository at this point in the history
Summary:
Spark [Ntile function](https://github.com/apache/spark/blob/f824d058b14e3c58b1c90f64fefc45fac105c7dd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala#L842) divides the rows for each window partition into n buckets ranging from 1 to at most n. Bucket values will differ by at most 1. If the number of rows in the partition does not divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. The difference between Spark SQL and Presto SQL is the return type: Spark SQL returns integer, whereas Presto SQL returns bigint. Following the approach used for the [nth_value()](https://github.com/facebookincubator/velox/blob/b9be1718a70f3f81d184cd1dc57134552a2ed96a/velox/functions/lib/window/NthValue.h#L20) function, this PR moves the Ntile.cpp file from
 `velox/functions/prestosql/window` into `velox/functions/lib/window`.
 It also provides `registerNtileBigint` and `registerNtileInteger` for Presto SQL and Spark SQL, respectively.

Pull Request resolved: facebookincubator#8597

Reviewed By: Yuhta

Differential Revision: D53517706

Pulled By: mbasmanova

fbshipit-source-id: 77b3ca5863a337021013f3aaa0ed932a68de862e
  • Loading branch information
JkSelf authored and facebook-github-bot committed Feb 7, 2024
1 parent 8f5f153 commit 6d56baf
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 30 deletions.
5 changes: 5 additions & 0 deletions velox/docs/functions/spark/window.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ Returns the rank of a value in a group of values. The rank is one plus the numbe
Returns the rank of a value in a group of values. This is similar to rank(), except that tie values do not produce gaps in the sequence.

.. spark:function:: ntile(n) -> integer
Divides the rows for each window partition into n buckets ranging from 1 to at most ``n``. Bucket values will differ by at most 1. If the number of rows in the partition does not divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket.

For example, with 6 rows and 4 buckets, the bucket values would be as follows: ``1 1 2 2 3 4``
3 changes: 2 additions & 1 deletion velox/functions/lib/window/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_library(velox_functions_window NthValue.cpp Rank.cpp RowNumber.cpp)
add_library(velox_functions_window NthValue.cpp Rank.cpp RowNumber.cpp
Ntile.cpp)

target_link_libraries(velox_functions_window velox_buffer velox_exec
Folly::folly)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,32 @@
#include "velox/expression/FunctionSignature.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::window::prestosql {
namespace facebook::velox::functions::window {

namespace {

template <typename TResult>
class NtileFunction : public exec::WindowFunction {
public:
explicit NtileFunction(
const std::vector<exec::WindowFunctionArg>& args,
const TypePtr& resultType,
velox::memory::MemoryPool* pool)
: WindowFunction(BIGINT(), pool, nullptr) {
: WindowFunction(resultType, pool, nullptr) {
if (args[0].constantValue) {
auto argBuckets = args[0].constantValue;
if (!argBuckets->isNullAt(0)) {
numFixedBuckets_ =
argBuckets->as<ConstantVector<int64_t>>()->valueAt(0);
argBuckets->as<ConstantVector<TResult>>()->valueAt(0);
VELOX_USER_CHECK_GE(
numFixedBuckets_.value(), 1, "{}", kBucketErrorString);
}
return;
}

bucketColumn_ = args[0].index;
bucketVector_ = BaseVector::create(BIGINT(), 0, pool);
bucketFlatVector_ = bucketVector_->asFlatVector<int64_t>();
bucketVector_ = BaseVector::create(resultType_, 0, pool);
bucketFlatVector_ = bucketVector_->asFlatVector<TResult>();
}

void resetPartition(const exec::WindowPartition* partition) override {
Expand Down Expand Up @@ -87,21 +89,21 @@ class NtileFunction : public exec::WindowFunction {
struct BucketMetrics {
// To compute the bucket number for a row, we find the number of rows in
// a bucket as the (number of rows in partition) / (number of buckets).
int64_t rowsPerBucket;
TResult rowsPerBucket;
// There could be some buckets with rowsPerBucket + 1 number of rows,
// as the partition rows might not be exactly divisible
// by the number of buckets. There are
// (number of rows in partition) % (number of buckets) such buckets.
int64_t bucketsWithExtraRow;
TResult bucketsWithExtraRow;
// When assigning bucket numbers, the first 'bucketsWithExtraRow' buckets
// will have (rowsPerBucket + 1) rows. This row number at this boundary is
// extraBucketsBoundary = bucketsWithExtraRow * (rowsPerBucket + 1). Beyond
// this row number in the partition, the buckets will have only
// rowsPerBucket number of rows. This boundary is useful when computing the
// bucket value.
int64_t extraBucketsBoundary;
TResult extraBucketsBoundary;

int64_t computeBucketValue(vector_size_t rowNumber) const {
TResult computeBucketValue(vector_size_t rowNumber) const {
if (rowNumber < extraBucketsBoundary) {
return rowNumber / (rowsPerBucket + 1) + 1;
}
Expand All @@ -115,7 +117,7 @@ class NtileFunction : public exec::WindowFunction {
vector_size_t numRows,
int64_t partitionOffset,
vector_size_t resultOffset,
int64_t* rawResultValues) {
TResult* rawResultValues) {
int64_t i = 0;
// This loop terminates if it reaches extraBucketBoundary or numRows
// in the result vector are filled.
Expand All @@ -130,7 +132,7 @@ class NtileFunction : public exec::WindowFunction {
}
};

BucketMetrics computeBucketMetrics(int64_t numBuckets) const {
BucketMetrics computeBucketMetrics(TResult numBuckets) const {
auto rowsPerBucket = numPartitionRows_ / numBuckets;
auto bucketsWithExtraRow = numPartitionRows_ % numBuckets;
auto extraBucketsBoundary = (rowsPerBucket + 1) * bucketsWithExtraRow;
Expand All @@ -145,7 +147,7 @@ class NtileFunction : public exec::WindowFunction {
partition_->extractColumn(
bucketColumn_.value(), partitionOffset_, numRows, 0, bucketVector_);

auto* resultFlatVector = result->asFlatVector<int64_t>();
auto* resultFlatVector = result->asFlatVector<TResult>();
auto* rawValues = resultFlatVector->mutableRawValues();
for (auto i = 0; i < numRows; i++) {
if (bucketFlatVector_->isNullAt(i)) {
Expand All @@ -170,7 +172,7 @@ class NtileFunction : public exec::WindowFunction {
vector_size_t resultOffset,
const VectorPtr& result) {
if (numFixedBuckets_.has_value()) {
auto rawValues = result->asFlatVector<int64_t>()->mutableRawValues();
auto rawValues = result->asFlatVector<TResult>()->mutableRawValues();
if (fixedBucketsMoreThanPartition_) {
std::iota(
rawValues + resultOffset,
Expand All @@ -183,7 +185,7 @@ class NtileFunction : public exec::WindowFunction {
} else {
// This is a function call with a constant null value. Set all result
// rows to null.
auto* resultVector = result->asFlatVector<int64_t>();
auto* resultVector = result->asFlatVector<TResult>();
auto mutableRawNulls = resultVector->mutableRawNulls();
bits::fillBits(
mutableRawNulls, resultOffset, resultOffset + numRows, bits::kNull);
Expand All @@ -195,7 +197,7 @@ class NtileFunction : public exec::WindowFunction {

// Number of buckets if a constant value. Is optional as the value could
// be null.
std::optional<int64_t> numFixedBuckets_;
std::optional<TResult> numFixedBuckets_;

// If number of buckets is greater than the partition rows, then the output
// bucket number is simply row number + 1. So bucket computation can be
Expand All @@ -209,29 +211,30 @@ class NtileFunction : public exec::WindowFunction {

// Current WindowPartition used for accessing rows in the apply method.
const exec::WindowPartition* partition_;
int64_t numPartitionRows_ = 0;
TResult numPartitionRows_ = 0;

// Denotes how far along the partition rows are output already.
int64_t partitionOffset_ = 0;

// Vector used to read the bucket column values.
VectorPtr bucketVector_;
FlatVector<int64_t>* bucketFlatVector_;
FlatVector<TResult>* bucketFlatVector_;

static const std::string kBucketErrorString;
};

const std::string NtileFunction::kBucketErrorString =
template <typename TResult>
const std::string NtileFunction<TResult>::kBucketErrorString =
"Buckets must be greater than 0";

} // namespace

void registerNtile(const std::string& name) {
// ntile(bigint) -> bigint.
template <typename TResult>
void registerNtile(const std::string& name, const std::string& type) {
std::vector<exec::FunctionSignaturePtr> signatures{
exec::FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("bigint")
.returnType(type)
.argumentType(type)
.build(),
};

Expand All @@ -240,13 +243,21 @@ void registerNtile(const std::string& name) {
std::move(signatures),
[name](
const std::vector<exec::WindowFunctionArg>& args,
const TypePtr& /*resultType*/,
const TypePtr& resultType,
bool /*ignoreNulls*/,
velox::memory::MemoryPool* pool,
HashStringAllocator* /*stringAllocator*/,
const core::QueryConfig& /*queryConfig*/)
-> std::unique_ptr<exec::WindowFunction> {
return std::make_unique<NtileFunction>(args, pool);
return std::make_unique<NtileFunction<TResult>>(args, resultType, pool);
});
}
} // namespace facebook::velox::window::prestosql

// Registers ntile() under 'name' using bigint for both the argument and the
// result type (the variant registered for Presto SQL).
void registerNtileBigint(const std::string& name) {
  const std::string typeName{"bigint"};
  registerNtile<int64_t>(name, typeName);
}
// Registers ntile() under 'name' using integer for both the argument and the
// result type (the variant registered for Spark SQL).
void registerNtileInteger(const std::string& name) {
  const std::string typeName{"integer"};
  registerNtile<int32_t>(name, typeName);
}

} // namespace facebook::velox::functions::window
8 changes: 8 additions & 0 deletions velox/functions/lib/window/RegistrationFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,12 @@ void registerDenseRankInteger(const std::string& name);
// Returns the percentage ranking of a value in a group of values.
void registerPercentRank(const std::string& name);

// Register the Presto function ntile() with the bigint data type
// for the return and input value.
void registerNtileBigint(const std::string& name);

// Register the Spark function ntile() with the integer data type
// for the return and input value.
void registerNtileInteger(const std::string& name);

} // namespace facebook::velox::functions::window
2 changes: 1 addition & 1 deletion velox/functions/prestosql/window/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
endif()

add_library(velox_window CumeDist.cpp FirstLastValue.cpp LeadLag.cpp Ntile.cpp
add_library(velox_window CumeDist.cpp FirstLastValue.cpp LeadLag.cpp
WindowFunctionsRegistration.cpp)

target_link_libraries(velox_window velox_buffer velox_exec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace facebook::velox::window {
namespace prestosql {

extern void registerCumeDist(const std::string& name);
extern void registerNtile(const std::string& name);
extern void registerNtileBigint(const std::string& name);
extern void registerFirstValue(const std::string& name);
extern void registerLastValue(const std::string& name);
extern void registerLag(const std::string& name);
Expand All @@ -33,7 +33,7 @@ void registerAllWindowFunctions(const std::string& prefix) {
functions::window::registerDenseRankBigint(prefix + "dense_rank");
functions::window::registerPercentRank(prefix + "percent_rank");
registerCumeDist(prefix + "cume_dist");
registerNtile(prefix + "ntile");
functions::window::registerNtileBigint(prefix + "ntile");
functions::window::registerNthValueBigint(prefix + "nth_value");
registerFirstValue(prefix + "first_value");
registerLastValue(prefix + "last_value");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ void registerWindowFunctions(const std::string& prefix) {
functions::window::registerRowNumberInteger(prefix + "row_number");
functions::window::registerRankInteger(prefix + "rank");
functions::window::registerDenseRankInteger(prefix + "dense_rank");
functions::window::registerNtileInteger(prefix + "ntile");
}

} // namespace facebook::velox::functions::window::sparksql
7 changes: 6 additions & 1 deletion velox/functions/sparksql/window/tests/SparkWindowTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ static const std::vector<std::string> kSparkWindowFunctions = {
std::string("nth_value(c0, c3)"),
std::string("row_number()"),
std::string("rank()"),
std::string("dense_rank()")};
std::string("dense_rank()"),
std::string("ntile(c3)"),
std::string("ntile(1)"),
std::string("ntile(7)"),
std::string("ntile(10)"),
std::string("ntile(16)")};

struct SparkWindowTestParam {
const std::string function;
Expand Down

0 comments on commit 6d56baf

Please sign in to comment.