Skip to content

Commit c4d3f6f

Browse files
ulysses-youfacebook-github-bot
authored andcommitted
Fix lead/lag for zero offset (facebookincubator#9026)
Summary: The lead and lag return incorrect result if the offset is zero and ignore nulls is true. The reason is that, the `rowNumberIgnoreNull` assumes the offset is bigger than zero. This pr does two changes: - if the offset is constant, then make a fast path to return the rowNumbers - if the offset is not constant, then add a pre-condition check if the offset is 0 and return current row Pull Request resolved: facebookincubator#9026 Reviewed By: Yuhta Differential Revision: D55158362 Pulled By: kagamiori fbshipit-source-id: 28082a1fee98c372f672bbd64db10af40fa685c8
1 parent 9b4a210 commit c4d3f6f

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

velox/functions/prestosql/window/LeadLag.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ class LeadLagFunction : public exec::WindowFunction {
111111
constantOffset->as<ConstantVector<int64_t>>()->valueAt(0);
112112
VELOX_USER_CHECK_GE(
113113
constantOffset_.value(), 0, "Offset must be at least 0");
114+
if (constantOffset_.value() == 0) {
115+
isConstantOffsetZero_ = true;
116+
}
114117
}
115118
} else {
116119
offsetIndex_ = offsetArg.index.value();
@@ -144,6 +147,11 @@ class LeadLagFunction : public exec::WindowFunction {
144147
std::fill(rowNumbers_.begin(), rowNumbers_.end(), kNullRow);
145148
return;
146149
}
150+
// If the offset is 0 then it means always return the current row.
151+
if (isConstantOffsetZero_) {
152+
std::iota(rowNumbers_.begin(), rowNumbers_.end(), partitionOffset_);
153+
return;
154+
}
147155

148156
auto constantOffsetValue = constantOffset_.value();
149157
// Set row number to kDefaultValueRow for out of range offset.
@@ -175,6 +183,11 @@ class LeadLagFunction : public exec::WindowFunction {
175183
rowNumbers_[i] = kDefaultValueRow;
176184
continue;
177185
}
186+
// If the offset is 0 then it means always return the current row.
187+
if (offset == 0) {
188+
rowNumbers_[i] = partitionOffset_ + i;
189+
continue;
190+
}
178191

179192
if constexpr (isLag) {
180193
if constexpr (ignoreNulls) {
@@ -204,6 +217,7 @@ class LeadLagFunction : public exec::WindowFunction {
204217
}
205218
}
206219

220+
// This method assumes the input offset > 0
207221
vector_size_t rowNumberIgnoreNull(
208222
const uint64_t* rawNulls,
209223
vector_size_t offset,
@@ -282,6 +296,7 @@ class LeadLagFunction : public exec::WindowFunction {
282296
// Value of the 'offset' if constant.
283297
std::optional<int64_t> constantOffset_;
284298
bool isConstantOffsetNull_ = false;
299+
bool isConstantOffsetZero_ = false;
285300

286301
// Index of the 'default_value' argument if default value is specified and not
287302
// constant.

velox/functions/prestosql/window/tests/LeadLagTest.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,35 @@ TEST_P(LeadLagTest, ignoreNullsInt64Offset) {
184184
assertResults(fn(fmt::format("c0, {} IGNORE NULLS", largeOffset)));
185185
}
186186

187+
TEST_P(LeadLagTest, zeroOffset) {
188+
auto data = makeRowVector({
189+
// Values with null.
190+
makeNullableFlatVector<int32_t>(
191+
{1, std::nullopt, 2, std::nullopt, std::nullopt}),
192+
// Values without null.
193+
makeFlatVector<int32_t>({1, 2, 3, 4, 5}),
194+
// Offsets.
195+
makeFlatVector<int64_t>({0, 0, 0, 0, 0}),
196+
});
197+
createDuckDbTable({data});
198+
199+
auto assertResults = [&](const std::string& functionSql) {
200+
auto queryInfo = buildWindowQuery({data}, functionSql, "order by c0", "");
201+
SCOPED_TRACE(queryInfo.functionSql);
202+
assertQuery(queryInfo.planNode, queryInfo.querySql);
203+
};
204+
205+
assertResults(fn("c0, 0"));
206+
assertResults(fn("c0, c2"));
207+
assertResults(fn("c0, 0 IGNORE NULLS"));
208+
assertResults(fn("c0, c2 IGNORE NULLS"));
209+
210+
assertResults(fn("c1, 0"));
211+
assertResults(fn("c1, c2"));
212+
assertResults(fn("c1, 0 IGNORE NULLS"));
213+
assertResults(fn("c1, c2 IGNORE NULLS"));
214+
}
215+
187216
TEST_P(LeadLagTest, defaultValue) {
188217
auto data = makeRowVector({
189218
// Values.

0 commit comments

Comments
 (0)