From c4d3f6f54a1346fce4022d0b381421b5130a0cd5 Mon Sep 17 00:00:00 2001 From: youxiduo Date: Thu, 21 Mar 2024 10:57:50 -0700 Subject: [PATCH] Fix lead/lag for zero offset (#9026) Summary: The lead and lag return incorrect result if the offset is zero and ignore nulls is true. The reason is that, the `rowNumberIgnoreNull` assumes the offset is bigger than zero. This pr does two changes: - if the offset is constant, then make a fast path to return the rowNumbers - if the offset is not constant, then add a pre-condition check if the offset is 0 and return current row Pull Request resolved: https://github.com/facebookincubator/velox/pull/9026 Reviewed By: Yuhta Differential Revision: D55158362 Pulled By: kagamiori fbshipit-source-id: 28082a1fee98c372f672bbd64db10af40fa685c8 --- velox/functions/prestosql/window/LeadLag.cpp | 15 ++++++++++ .../prestosql/window/tests/LeadLagTest.cpp | 29 +++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/velox/functions/prestosql/window/LeadLag.cpp b/velox/functions/prestosql/window/LeadLag.cpp index 41b342072465..8e4c598c4d63 100644 --- a/velox/functions/prestosql/window/LeadLag.cpp +++ b/velox/functions/prestosql/window/LeadLag.cpp @@ -111,6 +111,9 @@ class LeadLagFunction : public exec::WindowFunction { constantOffset->as>()->valueAt(0); VELOX_USER_CHECK_GE( constantOffset_.value(), 0, "Offset must be at least 0"); + if (constantOffset_.value() == 0) { + isConstantOffsetZero_ = true; + } } } else { offsetIndex_ = offsetArg.index.value(); @@ -144,6 +147,11 @@ class LeadLagFunction : public exec::WindowFunction { std::fill(rowNumbers_.begin(), rowNumbers_.end(), kNullRow); return; } + // If the offset is 0 then it means always return the current row. + if (isConstantOffsetZero_) { + std::iota(rowNumbers_.begin(), rowNumbers_.end(), partitionOffset_); + return; + } auto constantOffsetValue = constantOffset_.value(); // Set row number to kDefaultValueRow for out of range offset. @@ -175,6 +183,11 @@ class LeadLagFunction : public exec::WindowFunction { rowNumbers_[i] = kDefaultValueRow; continue; } + // If the offset is 0 then it means always return the current row. + if (offset == 0) { + rowNumbers_[i] = partitionOffset_ + i; + continue; + } if constexpr (isLag) { if constexpr (ignoreNulls) { @@ -204,6 +217,7 @@ class LeadLagFunction : public exec::WindowFunction { } } + // This method assumes the input offset > 0 vector_size_t rowNumberIgnoreNull( const uint64_t* rawNulls, vector_size_t offset, @@ -282,6 +296,7 @@ class LeadLagFunction : public exec::WindowFunction { // Value of the 'offset' if constant. std::optional constantOffset_; bool isConstantOffsetNull_ = false; + bool isConstantOffsetZero_ = false; // Index of the 'default_value' argument if default value is specified and not // constant. diff --git a/velox/functions/prestosql/window/tests/LeadLagTest.cpp b/velox/functions/prestosql/window/tests/LeadLagTest.cpp index eceb52c836ca..d733f20d3c5b 100644 --- a/velox/functions/prestosql/window/tests/LeadLagTest.cpp +++ b/velox/functions/prestosql/window/tests/LeadLagTest.cpp @@ -184,6 +184,35 @@ TEST_P(LeadLagTest, ignoreNullsInt64Offset) { assertResults(fn(fmt::format("c0, {} IGNORE NULLS", largeOffset))); } +TEST_P(LeadLagTest, zeroOffset) { + auto data = makeRowVector({ + // Values with null. + makeNullableFlatVector( + {1, std::nullopt, 2, std::nullopt, std::nullopt}), + // Values without null. + makeFlatVector({1, 2, 3, 4, 5}), + // Offsets. + makeFlatVector({0, 0, 0, 0, 0}), + }); + createDuckDbTable({data}); + + auto assertResults = [&](const std::string& functionSql) { + auto queryInfo = buildWindowQuery({data}, functionSql, "order by c0", ""); + SCOPED_TRACE(queryInfo.functionSql); + assertQuery(queryInfo.planNode, queryInfo.querySql); + }; + + assertResults(fn("c0, 0")); + assertResults(fn("c0, c2")); + assertResults(fn("c0, 0 IGNORE NULLS")); + assertResults(fn("c0, c2 IGNORE NULLS")); + + assertResults(fn("c1, 0")); + assertResults(fn("c1, c2")); + assertResults(fn("c1, 0 IGNORE NULLS")); + assertResults(fn("c1, c2 IGNORE NULLS")); +} + TEST_P(LeadLagTest, defaultValue) { auto data = makeRowVector({ // Values.