From 014c6fe9931f81c285d8e6adabc3a2cde920d085 Mon Sep 17 00:00:00 2001
From: Daniel Hunte <danielhunte@meta.com>
Date: Mon, 9 Dec 2024 09:57:06 -0800
Subject: [PATCH] fix(hashjoin): Turn off dynamic filter push downs for null
 aware right semi porject join (#11781)

Summary:

Currently, when there are no matches for a null aware right semi project join but the build side has nulls, the dynamic push down filter filters out all the rows from the probe side. This causes the probe side to be considered empty and therefore sets the entire match column to falses, even in rows where the match value should be null.

Reviewed By: Yuhta

Differential Revision: D66903863
---
 velox/exec/HashProbe.cpp          |   9 +--
 velox/exec/tests/HashJoinTest.cpp | 102 ++++++++++++++++++------------
 2 files changed, 63 insertions(+), 48 deletions(-)

diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp
index 60ad9d69c3296..a788cc8e2ae59 100644
--- a/velox/exec/HashProbe.cpp
+++ b/velox/exec/HashProbe.cpp
@@ -428,7 +428,8 @@ void HashProbe::asyncWaitForHashTable() {
     }
   } else if (
       (isInnerJoin(joinType_) || isLeftSemiFilterJoin(joinType_) ||
-       isRightSemiFilterJoin(joinType_) || isRightSemiProjectJoin(joinType_)) &&
+       isRightSemiFilterJoin(joinType_) ||
+       (isRightSemiProjectJoin(joinType_) && !nullAware_)) &&
       table_->hashMode() != BaseHashTable::HashMode::kHash && !isSpillInput() &&
       !hasMoreSpillData()) {
     // Find out whether there are any upstream operators that can accept dynamic
@@ -443,13 +444,9 @@ void HashProbe::asyncWaitForHashTable() {
     const auto channels = operatorCtx_->driverCtx()->driver->canPushdownFilters(
         this, keyChannels_);
 
-    // Null aware Right Semi Project join needs to know whether there are any
-    // nulls on the probe side. Hence, cannot filter these out.
-    const auto nullAllowed = isRightSemiProjectJoin(joinType_) && nullAware_;
-
     for (auto i = 0; i < keyChannels_.size(); ++i) {
       if (channels.find(keyChannels_[i]) != channels.end()) {
-        if (auto filter = buildHashers[i]->getFilter(nullAllowed)) {
+        if (auto filter = buildHashers[i]->getFilter(/*nullAllowed=*/false)) {
           dynamicFilters_.emplace(keyChannels_[i], std::move(filter));
         }
       }
diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp
index 85cea6fab72ef..3cd140ea90415 100644
--- a/velox/exec/tests/HashJoinTest.cpp
+++ b/velox/exec/tests/HashJoinTest.cpp
@@ -3294,62 +3294,80 @@ TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) {
       .run();
 }
 
-// Verify that dynamic filter pushed down from null-aware right semi project
-// join into table scan doesn't filter out nulls.
+// Verify that dynamic filter pushed down is turned off for null-aware right
+// semi project join.
 TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) {
-  auto probe = makeRowVector(
+  std::vector<RowVectorPtr> probes;
+  std::vector<RowVectorPtr> builds;
+  // Matches present:
+  probes.push_back(makeRowVector(
       {"t0"},
       {
           makeNullableFlatVector<int32_t>({1, std::nullopt, 2}),
-      });
+      }));
+  builds.push_back(makeRowVector(
+      {"u0"},
+      {
+          makeNullableFlatVector<int32_t>({1, 2, 3, std::nullopt}),
+      }));
 
-  auto build = makeRowVector(
+  // No matches present:
+  probes.push_back(makeRowVector(
+      {"t0"},
+      {
+          makeFlatVector<int32_t>({5, 6}),
+      }));
+  builds.push_back(makeRowVector(
       {"u0"},
       {
           makeNullableFlatVector<int32_t>({1, 2, 3, std::nullopt}),
-      });
+      }));
 
-  std::shared_ptr<TempFilePath> probeFile = TempFilePath::create();
-  writeToFile(probeFile->getPath(), {probe});
+  for (int i = 0; i < probes.size(); i++) {
+    RowVectorPtr& probe = probes[i];
+    RowVectorPtr& build = builds[i];
+    std::shared_ptr<TempFilePath> probeFile = TempFilePath::create();
+    writeToFile(probeFile->getPath(), {probe});
 
-  std::shared_ptr<TempFilePath> buildFile = TempFilePath::create();
-  writeToFile(buildFile->getPath(), {build});
+    std::shared_ptr<TempFilePath> buildFile = TempFilePath::create();
+    writeToFile(buildFile->getPath(), {build});
 
-  createDuckDbTable("t", {probe});
-  createDuckDbTable("u", {build});
+    createDuckDbTable("t", {probe});
+    createDuckDbTable("u", {build});
 
-  core::PlanNodeId probeScanId;
-  core::PlanNodeId buildScanId;
-  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
-  auto plan = PlanBuilder(planNodeIdGenerator)
-                  .tableScan(asRowType(probe->type()))
-                  .capturePlanNodeId(probeScanId)
-                  .hashJoin(
-                      {"t0"},
-                      {"u0"},
-                      PlanBuilder(planNodeIdGenerator)
-                          .tableScan(asRowType(build->type()))
-                          .capturePlanNodeId(buildScanId)
-                          .planNode(),
-                      "",
-                      {"u0", "match"},
-                      core::JoinType::kRightSemiProject,
-                      true /*nullAware*/)
-                  .planNode();
+    core::PlanNodeId probeScanId;
+    core::PlanNodeId buildScanId;
+    auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
+    auto plan = PlanBuilder(planNodeIdGenerator)
+                    .tableScan(asRowType(probe->type()))
+                    .capturePlanNodeId(probeScanId)
+                    .hashJoin(
+                        {"t0"},
+                        {"u0"},
+                        PlanBuilder(planNodeIdGenerator)
+                            .tableScan(asRowType(build->type()))
+                            .capturePlanNodeId(buildScanId)
+                            .planNode(),
+                        "",
+                        {"u0", "match"},
+                        core::JoinType::kRightSemiProject,
+                        true /*nullAware*/)
+                    .planNode();
 
-  SplitInput splitInput = {
-      {probeScanId,
-       {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
-      {buildScanId,
-       {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
-  };
+    SplitInput splitInput = {
+        {probeScanId,
+         {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}},
+        {buildScanId,
+         {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}},
+    };
 
-  HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
-      .planNode(plan)
-      .inputSplits(splitInput)
-      .checkSpillStats(false)
-      .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u")
-      .run();
+    HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
+        .planNode(plan)
+        .inputSplits(splitInput)
+        .checkSpillStats(false)
+        .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u")
+        .run();
+  }
 }
 
 TEST_F(HashJoinTest, duplicateJoinKeys) {