Skip to content

Commit

Permalink
Fix literal handling in Window functions (apache#13428)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 authored Jun 19, 2024
1 parent bb42575 commit 55b6024
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,51 @@ public void testAggregationUDFV2()
assertEquals(row.get(0).asDouble(), 16071.0 / 2);
}

@Test
public void testWindowAggregationV2()
throws Exception {
setUseMultiStageQueryEngine(true);
String tmpTableQuery =
"select DaysSinceEpoch, count(*) as num_trips from mytable GROUP BY DaysSinceEpoch order by DaysSinceEpoch";
JsonNode tmpTableResult = postQuery(tmpTableQuery).get("resultTable").get("rows");

String query = "WITH tmp AS (\n"
+ " select count(*) as num_trips, DaysSinceEpoch from mytable GROUP BY DaysSinceEpoch\n"
+ ")\n"
+ "\n"
+ "SELECT\n"
+ " DaysSinceEpoch,\n"
+ " num_trips,\n"
+ " LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS previous_num_trips,\n"
+ " num_trips - LAG(num_trips, 2) OVER (ORDER BY DaysSinceEpoch) AS difference\n"
+ "FROM\n"
+ " tmp";
JsonNode response = postQuery(query);
JsonNode resultTable = response.get("resultTable");
assertEquals(resultTable.get("dataSchema").get("columnDataTypes").toString(),
"[\"INT\",\"LONG\",\"LONG\",\"LONG\"]");
JsonNode rows = resultTable.get("rows");
assertEquals(rows.size(), 364);
for (int i = 0; i < 2; i++) {
JsonNode row = rows.get(i);
JsonNode tmpTableRow = tmpTableResult.get(i);
assertEquals(row.size(), 4);
assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
assertTrue(row.get(2).isNull());
assertTrue(row.get(2).isNull());
}
for (int i = 2; i < 363; i++) {
JsonNode row = rows.get(i);
assertEquals(row.size(), 4);
JsonNode tmpTableRow = tmpTableResult.get(i);
assertEquals(row.get(0).asInt(), tmpTableRow.get(0).asInt());
assertEquals(row.get(1).asLong(), tmpTableRow.get(1).asLong());
assertEquals(rows.get(i - 2).get(1).asLong(), row.get(2).asLong());
assertEquals(row.get(1).asLong() - row.get(2).asLong(), row.get(3).asLong());
}
}

@Test(dataProvider = "useBothQueryEngines")
public void testSelectionUDF(boolean useMultiStageQueryEngine)
throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilderFactory;
Expand Down Expand Up @@ -92,7 +93,7 @@ public void onMatch(RelOptRuleCall call) {
// Perform all validations
validateWindows(window);

Window.Group windowGroup = window.groups.get(0);
Window.Group windowGroup = updateLiteralArgumentsInWindowGroup(window);
if (windowGroup.keys.isEmpty() && windowGroup.orderKeys.getKeys().isEmpty()) {
// Empty OVER()
// Add a single Exchange for empty OVER() since no sort is required
Expand All @@ -111,7 +112,8 @@ public void onMatch(RelOptRuleCall call) {
PinotLogicalExchange exchange = PinotLogicalExchange.create(windowInput,
RelDistributions.hash(Collections.emptyList()));
call.transformTo(
LogicalWindow.create(window.getTraitSet(), exchange, window.constants, window.getRowType(), window.groups));
LogicalWindow.create(window.getTraitSet(), exchange, window.constants, window.getRowType(),
List.of(windowGroup)));
} else if (windowGroup.keys.isEmpty() && !windowGroup.orderKeys.getKeys().isEmpty()) {
// Only ORDER BY
// Add a LogicalSortExchange with collation on the order by key(s) and an empty hash partition key
Expand All @@ -121,7 +123,7 @@ public void onMatch(RelOptRuleCall call) {
PinotLogicalSortExchange sortExchange = PinotLogicalSortExchange.create(windowInput,
RelDistributions.hash(Collections.emptyList()), windowGroup.orderKeys, false, true);
call.transformTo(LogicalWindow.create(window.getTraitSet(), sortExchange, window.constants, window.getRowType(),
window.groups));
List.of(windowGroup)));
} else {
// All other variants
// Assess whether this is a PARTITION BY only query or not (includes queries of the type where PARTITION BY and
Expand All @@ -134,7 +136,7 @@ public void onMatch(RelOptRuleCall call) {
PinotLogicalExchange exchange = PinotLogicalExchange.create(windowInput,
RelDistributions.hash(windowGroup.keys.toList()));
call.transformTo(LogicalWindow.create(window.getTraitSet(), exchange, window.constants, window.getRowType(),
window.groups));
List.of(windowGroup)));
} else {
// PARTITION BY and ORDER BY on different key(s)
// Add a LogicalSortExchange hashed on the partition by keys and collation based on order by keys
Expand All @@ -145,11 +147,50 @@ public void onMatch(RelOptRuleCall call) {
PinotLogicalSortExchange sortExchange = PinotLogicalSortExchange.create(windowInput,
RelDistributions.hash(windowGroup.keys.toList()), windowGroup.orderKeys, false, true);
call.transformTo(LogicalWindow.create(window.getTraitSet(), sortExchange, window.constants, window.getRowType(),
window.groups));
List.of(windowGroup)));
}
}
}

private Window.Group updateLiteralArgumentsInWindowGroup(Window window) {
Window.Group oldWindowGroup = window.groups.get(0);
int windowInputSize = window.getInput().getRowType().getFieldCount();
ImmutableList<Window.RexWinAggCall> oldAggCalls = oldWindowGroup.aggCalls;
List<Window.RexWinAggCall> newAggCallWindow = new ArrayList<>(oldAggCalls.size());
boolean aggCallChanged = false;
for (Window.RexWinAggCall oldAggCall : oldAggCalls) {
boolean changed = false;
List<RexNode> oldAggCallArgList = oldAggCall.getOperands();
List<RexNode> rexList = new ArrayList<>(oldAggCallArgList.size());
for (RexNode rexNode : oldAggCallArgList) {
RexNode newRexNode = rexNode;
if (rexNode instanceof RexInputRef) {
RexInputRef inputRef = (RexInputRef) rexNode;
int inputRefIndex = inputRef.getIndex();
// If the input reference is greater than the window input size, it is a reference to the constants
if (inputRefIndex >= windowInputSize) {
newRexNode = window.constants.get(inputRefIndex - windowInputSize);
changed = true;
aggCallChanged = true;
}
}
rexList.add(newRexNode);
}
if (changed) {
newAggCallWindow.add(
new Window.RexWinAggCall((SqlAggFunction) oldAggCall.getOperator(), oldAggCall.type, rexList,
oldAggCall.ordinal, oldAggCall.distinct, oldAggCall.ignoreNulls));
} else {
newAggCallWindow.add(oldAggCall);
}
}
if (aggCallChanged) {
return new Window.Group(oldWindowGroup.keys, oldWindowGroup.isRows, oldWindowGroup.lowerBound,
oldWindowGroup.upperBound, oldWindowGroup.orderKeys, newAggCallWindow);
}
return oldWindowGroup;
}

private void validateWindows(Window window) {
int numGroups = window.groups.size();
// For Phase 1 we only handle single window groups
Expand Down

0 comments on commit 55b6024

Please sign in to comment.