diff --git a/velox/functions/prestosql/window/LeadLag.cpp b/velox/functions/prestosql/window/LeadLag.cpp index 41b342072465..8e4c598c4d63 100644 --- a/velox/functions/prestosql/window/LeadLag.cpp +++ b/velox/functions/prestosql/window/LeadLag.cpp @@ -111,6 +111,9 @@ class LeadLagFunction : public exec::WindowFunction { constantOffset->as>()->valueAt(0); VELOX_USER_CHECK_GE( constantOffset_.value(), 0, "Offset must be at least 0"); + if (constantOffset_.value() == 0) { + isConstantOffsetZero_ = true; + } } } else { offsetIndex_ = offsetArg.index.value(); @@ -144,6 +147,11 @@ class LeadLagFunction : public exec::WindowFunction { std::fill(rowNumbers_.begin(), rowNumbers_.end(), kNullRow); return; } + // If the offset is 0 then it means always return the current row. + if (isConstantOffsetZero_) { + std::iota(rowNumbers_.begin(), rowNumbers_.end(), partitionOffset_); + return; + } auto constantOffsetValue = constantOffset_.value(); // Set row number to kDefaultValueRow for out of range offset. @@ -175,6 +183,11 @@ class LeadLagFunction : public exec::WindowFunction { rowNumbers_[i] = kDefaultValueRow; continue; } + // If the offset is 0 then it means always return the current row. + if (offset == 0) { + rowNumbers_[i] = partitionOffset_ + i; + continue; + } if constexpr (isLag) { if constexpr (ignoreNulls) { @@ -204,6 +217,7 @@ class LeadLagFunction : public exec::WindowFunction { } } + // This method assumes the input offset > 0 vector_size_t rowNumberIgnoreNull( const uint64_t* rawNulls, vector_size_t offset, @@ -282,6 +296,7 @@ class LeadLagFunction : public exec::WindowFunction { // Value of the 'offset' if constant. std::optional constantOffset_; bool isConstantOffsetNull_ = false; + bool isConstantOffsetZero_ = false; // Index of the 'default_value' argument if default value is specified and not // constant. diff --git a/velox/functions/prestosql/window/tests/LeadLagTest.cpp b/velox/functions/prestosql/window/tests/LeadLagTest.cpp index eceb52c836ca..d733f20d3c5b 100644 --- a/velox/functions/prestosql/window/tests/LeadLagTest.cpp +++ b/velox/functions/prestosql/window/tests/LeadLagTest.cpp @@ -184,6 +184,35 @@ TEST_P(LeadLagTest, ignoreNullsInt64Offset) { assertResults(fn(fmt::format("c0, {} IGNORE NULLS", largeOffset))); } +TEST_P(LeadLagTest, zeroOffset) { + auto data = makeRowVector({ + // Values with null. + makeNullableFlatVector( + {1, std::nullopt, 2, std::nullopt, std::nullopt}), + // Values without null. + makeFlatVector({1, 2, 3, 4, 5}), + // Offsets. + makeFlatVector({0, 0, 0, 0, 0}), + }); + createDuckDbTable({data}); + + auto assertResults = [&](const std::string& functionSql) { + auto queryInfo = buildWindowQuery({data}, functionSql, "order by c0", ""); + SCOPED_TRACE(queryInfo.functionSql); + assertQuery(queryInfo.planNode, queryInfo.querySql); + }; + + assertResults(fn("c0, 0")); + assertResults(fn("c0, c2")); + assertResults(fn("c0, 0 IGNORE NULLS")); + assertResults(fn("c0, c2 IGNORE NULLS")); + + assertResults(fn("c1, 0")); + assertResults(fn("c1, c2")); + assertResults(fn("c1, 0 IGNORE NULLS")); + assertResults(fn("c1, c2 IGNORE NULLS")); +} + TEST_P(LeadLagTest, defaultValue) { auto data = makeRowVector({ // Values.