Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(hashjoin): Turn off dynamic filter push downs for null aware right semi porject join #11781

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions velox/exec/HashProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,8 @@ void HashProbe::asyncWaitForHashTable() {
}
} else if (
(isInnerJoin(joinType_) || isLeftSemiFilterJoin(joinType_) ||
isRightSemiFilterJoin(joinType_) || isRightSemiProjectJoin(joinType_)) &&
isRightSemiFilterJoin(joinType_) ||
(isRightSemiProjectJoin(joinType_) && !nullAware_)) &&
table_->hashMode() != BaseHashTable::HashMode::kHash && !isSpillInput() &&
!hasMoreSpillData()) {
// Find out whether there are any upstream operators that can accept dynamic
Expand All @@ -443,13 +444,9 @@ void HashProbe::asyncWaitForHashTable() {
const auto channels = operatorCtx_->driverCtx()->driver->canPushdownFilters(
this, keyChannels_);

// Null aware Right Semi Project join needs to know whether there are any
// nulls on the probe side. Hence, cannot filter these out.
const auto nullAllowed = isRightSemiProjectJoin(joinType_) && nullAware_;

for (auto i = 0; i < keyChannels_.size(); ++i) {
if (channels.find(keyChannels_[i]) != channels.end()) {
if (auto filter = buildHashers[i]->getFilter(nullAllowed)) {
if (auto filter = buildHashers[i]->getFilter(/*nullAllowed=*/false)) {
dynamicFilters_.emplace(keyChannels_[i], std::move(filter));
}
}
Expand Down Expand Up @@ -1220,21 +1217,24 @@ void HashProbe::prepareFilterRowsForNullAwareJoin(
filterInputColumnDecodedVector_.decode(
*filterInput->childAt(projection.outputChannel), filterInputRows_);
if (filterInputColumnDecodedVector_.mayHaveNulls()) {
SelectivityVector nullsInActiveRows(numRows);
memcpy(
nullsInActiveRows.asMutableRange().bits(),
filterInputColumnDecodedVector_.nulls(&filterInputRows_),
bits::nbytes(numRows));
// All rows that are not active count as non-null here.
bits::orWithNegatedBits(
nullsInActiveRows.asMutableRange().bits(),
filterInputRows_.asRange().bits(),
0,
numRows);
// NOTE: the false value of a raw null bit indicates null so we OR
// with negative of the raw bit.
bits::orWithNegatedBits(
rawNullRows, nullsInActiveRows.asRange().bits(), 0, numRows);
if (const uint64_t* nulls =
filterInputColumnDecodedVector_.nulls(&filterInputRows_)) {
SelectivityVector nullsInActiveRows(numRows);
memcpy(
nullsInActiveRows.asMutableRange().bits(),
nulls,
bits::nbytes(numRows));
// All rows that are not active count as non-null here.
bits::orWithNegatedBits(
nullsInActiveRows.asMutableRange().bits(),
filterInputRows_.asRange().bits(),
0,
numRows);
// NOTE: the false value of a raw null bit indicates null so we OR
// with negative of the raw bit.
bits::orWithNegatedBits(
rawNullRows, nullsInActiveRows.asRange().bits(), 0, numRows);
}
}
}
nullFilterInputRows_.updateBounds();
Expand Down
46 changes: 36 additions & 10 deletions velox/exec/fuzzer/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,12 @@ std::optional<std::string> DuckQueryRunner::toSql(
return out.str();
};

const auto& equiClausesToSql = [](auto joinNode) {
const auto filterToSql = [](core::TypedExprPtr filter) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter);
return toCallSql(call);
};

const auto& joinConditionAsSql = [&](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -363,6 +368,9 @@ std::optional<std::string> DuckQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
out << " AND " << filterToSql(joinNode->filter());
}
return out.str();
};

Expand All @@ -378,39 +386,56 @@ std::optional<std::string> DuckQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
// Multiple columns returned by a scalar subquery is not supported in
// DuckDB. A scalar subquery expression is a subquery that returns one
// result row from exactly one column for every input row.
if (joinNode->leftKeys().size() > 1) {
return std::nullopt;
}
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
break;
case core::JoinType::kLeftSemiProject:
if (joinNode->isNullAware()) {
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ") FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << ") FROM t";
}
break;
case core::JoinType::kAnti:
if (joinNode->isNullAware()) {
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode);
sql << ")";
}
break;
default:
Expand All @@ -424,6 +449,7 @@ std::optional<std::string> DuckQueryRunner::toSql(
std::optional<std::string> DuckQueryRunner::toSql(
const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode) {
std::stringstream sql;
sql << "SELECT " << folly::join(", ", joinNode->outputType()->names());

// Nested loop join without filter.
VELOX_CHECK(
Expand Down
47 changes: 36 additions & 11 deletions velox/exec/fuzzer/PrestoQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,12 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return out.str();
};

const auto equiClausesToSql = [](auto joinNode) {
const auto filterToSql = [](core::TypedExprPtr filter) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter);
return toCallSql(call);
};

const auto& joinConditionAsSql = [&](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -584,6 +589,9 @@ std::optional<std::string> PrestoQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
out << " AND " << filterToSql(joinNode->filter());
}
return out.str();
};

Expand All @@ -599,52 +607,69 @@ std::optional<std::string> PrestoQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
// Multiple columns returned by a scalar subquery is not supported in
// Presto. A scalar subquery expression is a subquery that returns one
// result row from exactly one column for every input row.
if (joinNode->leftKeys().size() > 1) {
return std::nullopt;
}
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
break;
case core::JoinType::kLeftSemiProject:
if (joinNode->isNullAware()) {
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ") FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode);
sql << ") FROM t";
}
break;
case core::JoinType::kAnti:
if (joinNode->isNullAware()) {
sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys())
<< " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys())
<< " FROM u)";
<< " FROM u";
if (joinNode->filter()) {
sql << " WHERE " << filterToSql(joinNode->filter());
}
sql << ")";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode);
sql << ")";
}
break;
default:
VELOX_UNREACHABLE(
"Unknown join type: {}", static_cast<int>(joinNode->joinType()));
}

return sql.str();
}

std::optional<std::string> PrestoQueryRunner::toSql(
const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode) {
std::stringstream sql;
sql << "SELECT " << folly::join(", ", joinNode->outputType()->names());

// Nested loop join without filter.
VELOX_CHECK(
Expand Down
102 changes: 60 additions & 42 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3294,62 +3294,80 @@ TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) {
.run();
}

// Verify that dynamic filter pushed down from null-aware right semi project
// join into table scan doesn't filter out nulls.
// Verify that dynamic filter pushed down is turned off for null-aware right
// semi project join.
TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) {
auto probe = makeRowVector(
std::vector<RowVectorPtr> probes;
std::vector<RowVectorPtr> builds;
// Matches present:
probes.push_back(makeRowVector(
{"t0"},
{
makeNullableFlatVector<int32_t>({1, std::nullopt, 2}),
});
}));
builds.push_back(makeRowVector(
{"u0"},
{
makeNullableFlatVector<int32_t>({1, 2, 3, std::nullopt}),
}));

auto build = makeRowVector(
// No matches present:
probes.push_back(makeRowVector(
{"t0"},
{
makeFlatVector<int32_t>({5, 6}),
}));
builds.push_back(makeRowVector(
{"u0"},
{
makeNullableFlatVector<int32_t>({1, 2, 3, std::nullopt}),
});
}));

std::shared_ptr<TempFilePath> probeFile = TempFilePath::create();
writeToFile(probeFile->getPath(), {probe});
for (int i = 0; i < probes.size(); i++) {
RowVectorPtr& probe = probes[i];
RowVectorPtr& build = builds[i];
std::shared_ptr<TempFilePath> probeFile = TempFilePath::create();
writeToFile(probeFile->getPath(), {probe});

std::shared_ptr<TempFilePath> buildFile = TempFilePath::create();
writeToFile(buildFile->getPath(), {build});
std::shared_ptr<TempFilePath> buildFile = TempFilePath::create();
writeToFile(buildFile->getPath(), {build});

createDuckDbTable("t", {probe});
createDuckDbTable("u", {build});
createDuckDbTable("t", {probe});
createDuckDbTable("u", {build});

core::PlanNodeId probeScanId;
core::PlanNodeId buildScanId;
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probe->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(build->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"",
{"u0", "match"},
core::JoinType::kRightSemiProject,
true /*nullAware*/)
.planNode();
core::PlanNodeId probeScanId;
core::PlanNodeId buildScanId;
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan = PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(probe->type()))
.capturePlanNodeId(probeScanId)
.hashJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator)
.tableScan(asRowType(build->type()))
.capturePlanNodeId(buildScanId)
.planNode(),
"",
{"u0", "match"},
core::JoinType::kRightSemiProject,
true /*nullAware*/)
.planNode();

SplitInput splitInput = {
{probeScanId,
{exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
{buildScanId,
{exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
};
SplitInput splitInput = {
{probeScanId,
{exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
{buildScanId,
{exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
};

HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(plan)
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u")
.run();
HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
.planNode(plan)
.inputSplits(splitInput)
.checkSpillStats(false)
.referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u")
.run();
}
}

TEST_F(HashJoinTest, duplicateJoinKeys) {
Expand Down
Loading