Skip to content

Commit

Permalink
Fix integer overflow for window ROWS frame (facebookincubator#8870)
Browse files Browse the repository at this point in the history
Summary:
For a window ROWS frame, if a large preceding/following value (int32/int64) is used, an integer overflow can occur while computing `rawFrameBounds` (int32), which produces an unexpected frame and therefore a wrong result.

Pull Request resolved: facebookincubator#8870

Reviewed By: amitkdutta

Differential Revision: D55287372

Pulled By: kagamiori

fbshipit-source-id: f7426fb1f3c7939165176071f12637cb41516a3f
  • Loading branch information
PHILO-HE authored and facebook-github-bot committed Mar 26, 2024
1 parent 4f3d32f commit 0618c7f
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 5 deletions.
1 change: 0 additions & 1 deletion velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,6 @@ class WindowNode : public PlanNode {
/// Frame bounds can be CURRENT ROW, UNBOUNDED PRECEDING(FOLLOWING)
/// and k PRECEDING(FOLLOWING). K could be a constant or column.
///
/// k PRECEDING(FOLLOWING) is only supported for ROW frames now.
/// k has to be of integer or bigint type.
struct Frame {
WindowType type;
Expand Down
43 changes: 40 additions & 3 deletions velox/exec/Window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,16 @@ void updateKRowsOffsetsColumn(
// moves ahead.
int precedingFactor = isKPreceding ? -1 : 1;
for (auto i = 0; i < numRows; i++) {
rawFrameBounds[i] =
(startRow + i) + vector_size_t(precedingFactor * offsets[i]);
auto startValue = (int64_t)(startRow + i) + precedingFactor * offsets[i];
if (startValue < INT32_MIN) {
rawFrameBounds[i] = 0;
} else if (startValue > INT32_MAX) {
// computeValidFrames will replace INT32_MAX set here
// with partition's final row index.
rawFrameBounds[i] = INT32_MAX;
} else {
rawFrameBounds[i] = startValue;
}
}
}

Expand All @@ -296,7 +304,36 @@ void Window::updateKRowsFrameBounds(
if (frameArg.index == kConstantChannel) {
auto constantOffset = frameArg.constant.value();
auto startValue =
startRow + (isKPreceding ? -constantOffset : constantOffset);
(int64_t)startRow + (isKPreceding ? -constantOffset : constantOffset);

if (isKPreceding) {
if (startValue < INT32_MIN) {
// For overflow in kPreceding frames, startValue < INT32_MIN. Since the max
// number of rows in a partition is INT32_MAX, the frame bound will
// always be clamped to the first row of the partition.
std::fill_n(rawFrameBounds, numRows, 0);
return;
}
std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue);
return;
}

// KFollowing.
// The first row index at which overflow happens.
int32_t overflowStart;
if (startValue > (int64_t)INT32_MAX) {
overflowStart = 0;
} else {
overflowStart = INT32_MAX - startValue + 1;
}
if (overflowStart >= 0 && overflowStart < numRows) {
std::iota(rawFrameBounds, rawFrameBounds + overflowStart, startValue);
// For the remaining rows, where overflow happens, use INT32_MAX.
// computeValidFrames will replace it with the partition's final row index.
std::fill_n(
rawFrameBounds + overflowStart, numRows - overflowStart, INT32_MAX);
return;
}
std::iota(rawFrameBounds, rawFrameBounds + numRows, startValue);
} else {
currentPartition_->extractColumn(
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/lib/window/tests/WindowTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class WindowTestBase : public exec::test::OperatorTestBase {
void testKRangeFrames(const std::string& function);

/// ParseOptions for the DuckDB Parser. nth_value in Spark expects to parse
/// integer as bigint vs bigint in Presto. The default is to parse integer
/// integer as int vs bigint in Presto. The default is to parse integer
/// as bigint (Presto behavior).
parse::ParseOptions options_;

Expand Down
119 changes: 119 additions & 0 deletions velox/functions/prestosql/window/tests/AggregateWindowTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,124 @@ TEST_F(AggregateWindowTest, testDecimal) {
testAggregate(DECIMAL(5, 2));
testAggregate(DECIMAL(20, 5));
}

// Checks count(c1) over ROWS frames whose k PRECEDING / k FOLLOWING offsets
// sit at or beyond the int32 limit. Offsets are supplied both as constants and
// as int32 (c2) and int64 (c3) columns.
TEST_F(AggregateWindowTest, integerOverflowRowsFrame) {
  // c0 is the partition key; c1 is the ordering key and the counted column.
  auto c0 = makeFlatVector<int64_t>({-1, -1, -1, -1, -1, -1, 2, 2, 2, 2});
  auto c1 = makeFlatVector<double>({-1, -2, -3, -4, -5, -6, -7, -8, -9, -10});
  // int32 frame offsets. INT32_MAX: 2147483647
  auto c2 = makeFlatVector<int32_t>(
      {1,
       2147483647,
       2147483646,
       2147483645,
       1,
       10,
       1,
       2147483647,
       2147483646,
       2147483645});
  // int64 frame offsets, several of them above INT32_MAX.
  auto c3 = makeFlatVector<int64_t>(
      {2147483651,
       1,
       2147483650,
       10,
       2147483648,
       2147483647,
       2,
       2147483646,
       2147483650,
       2147483648});
  auto input = makeRowVector({c0, c1, c2, c3});
  std::string partitionAndOrder = "partition by c0 order by c1 desc";

  // Runs count(c1) with the given frame and compares the result against the
  // input columns followed by the expected per-row counts.
  auto verifyCounts = [&](const std::string& frame, const auto& counts) {
    auto expectedRows = makeRowVector({c0, c1, c2, c3, counts});
    WindowTestBase::testWindowFunction(
        {input}, "count(c1)", partitionAndOrder, frame, expectedRows);
  };

  // Constant following larger than INT32_MAX (2147483647).
  verifyCounts(
      "rows between 0 preceding and 2147483650 following",
      makeFlatVector<int64_t>({6, 5, 4, 3, 2, 1, 4, 3, 2, 1}));

  // Overflow starts happening from the middle of the partition.
  verifyCounts(
      "rows between 0 preceding and 2147483645 following",
      makeFlatVector<int64_t>({6, 5, 4, 3, 2, 1, 4, 3, 2, 1}));

  // Column-specified following (int32).
  verifyCounts(
      "rows between 0 preceding and c2 following",
      makeFlatVector<int64_t>({2, 5, 4, 3, 2, 1, 2, 3, 2, 1}));

  // Column-specified following (int64).
  verifyCounts(
      "rows between 0 preceding and c3 following",
      makeFlatVector<int64_t>({6, 2, 4, 3, 2, 1, 3, 3, 2, 1}));

  // Constant preceding larger than INT32_MAX.
  verifyCounts(
      "rows between 2147483650 preceding and 0 following",
      makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6, 1, 2, 3, 4}));

  // Column-specified preceding (int32).
  verifyCounts(
      "rows between c2 preceding and 0 following",
      makeFlatVector<int64_t>({1, 2, 3, 4, 2, 6, 1, 2, 3, 4}));

  // Column-specified preceding (int64).
  verifyCounts(
      "rows between c3 preceding and 0 following",
      makeFlatVector<int64_t>({1, 2, 3, 4, 5, 6, 1, 2, 3, 4}));

  // Constant preceding & following both larger than INT32_MAX.
  verifyCounts(
      "rows between 2147483650 preceding and 2147483651 following",
      makeFlatVector<int64_t>({6, 6, 6, 6, 6, 6, 4, 4, 4, 4}));
}

}; // namespace
}; // namespace facebook::velox::window::test

0 comments on commit 0618c7f

Please sign in to comment.