From 794b8b671dfda6cae8a03383522941f63854c742 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Mon, 16 Sep 2024 17:08:39 -0700 Subject: [PATCH] Support full outer join in smj (#10247) Summary: Follow-up to https://github.com/facebookincubator/velox/pull/10148. A full outer join is the union of a left join and a right join. Pull Request resolved: https://github.com/facebookincubator/velox/pull/10247 Reviewed By: pedroerp Differential Revision: D62760535 Pulled By: kevinwilfong fbshipit-source-id: eb5a656872581986902a1f28ee4450492d03bbe4 --- velox/core/PlanNode.cpp | 1 + velox/exec/MergeJoin.cpp | 189 ++++++++++++++++++++++++++--- velox/exec/tests/MergeJoinTest.cpp | 156 ++++++++++++++++++++++++ 3 files changed, 328 insertions(+), 18 deletions(-) diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 5d48bb3b9217..a506cb07dbc5 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -1130,6 +1130,7 @@ bool MergeJoinNode::isSupported(core::JoinType joinType) { case core::JoinType::kLeftSemiFilter: case core::JoinType::kRightSemiFilter: case core::JoinType::kAnti: + case core::JoinType::kFull: return true; default: diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 6daed9fa9c72..438d7cb87689 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -89,7 +89,7 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { @@ -207,10 +207,23 @@ int32_t MergeJoin::compare( const RowVectorPtr& otherBatch, vector_size_t otherIndex) { for (auto i = 0; i < keys.size(); ++i) { + CompareFlags compareFlags = { + .equalsOnly = true, + .nullHandlingMode = + CompareFlags::NullHandlingMode::kNullAsIndeterminate}; auto compare = batch->childAt(keys[i])->compare( - 
otherBatch->childAt(otherKeys[i]).get(), index, otherIndex); - if (compare != 0) { - return compare; + otherBatch->childAt(otherKeys[i]).get(), + index, + otherIndex, + compareFlags); + + // Comparing null with anything will return std::nullopt. + if (!compare.has_value()) { + // The SQL semantics of Presto and Spark will always return false if + // comparing a NULL value with any other value. + return -1; + } else if (compare.value() != 0) { + return compare.value(); } } @@ -272,7 +285,8 @@ void copyRow( void MergeJoin::addOutputRowForLeftJoin( const RowVectorPtr& left, vector_size_t leftIndex) { - VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_)); + VELOX_USER_CHECK( + isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_)); rawLeftIndices_[outputSize_] = leftIndex; for (const auto& projection : rightProjections_) { @@ -291,7 +305,7 @@ void MergeJoin::addOutputRowForLeftJoin( void MergeJoin::addOutputRowForRightJoin( const RowVectorPtr& right, vector_size_t rightIndex) { - VELOX_USER_CHECK(isRightJoin(joinType_)); + VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_)); rawRightIndices_[outputSize_] = rightIndex; for (const auto& projection : leftProjections_) { @@ -672,10 +686,14 @@ RowVectorPtr MergeJoin::getOutput() { } if (rightInput_) { - rightIndex_ = firstNonNull(rightInput_, rightKeys_); - if (rightIndex_ == rightInput_->size()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; + if (isFullJoin(joinType_)) { + rightIndex_ = 0; + } else { + rightIndex_ = firstNonNull(rightInput_, rightKeys_); + if (rightIndex_ == rightInput_->size()) { + // Ran out of rows on the right side. 
+ rightInput_ = nullptr; + } } } else { noMoreRightInput_ = true; @@ -731,10 +749,14 @@ RowVectorPtr MergeJoin::doGetOutput() { return nullptr; } if (rightMatch_->inputs.back() == rightInput_) { - rightIndex_ = - firstNonNull(rightInput_, rightKeys_, rightMatch_->endIndex); - if (rightIndex_ == rightInput_->size()) { - rightInput_ = nullptr; + if (isFullJoin(joinType_)) { + rightIndex_ = rightMatch_->endIndex; + } else { + rightIndex_ = + firstNonNull(rightInput_, rightKeys_, rightMatch_->endIndex); + if (rightIndex_ == rightInput_->size()) { + rightInput_ = nullptr; + } } } } else if (noMoreRightInput_) { @@ -809,6 +831,62 @@ RowVectorPtr MergeJoin::doGetOutput() { } } + if (noMoreRightInput_ && output_) { + output_->resize(outputSize_); + return std::move(output_); + } + } else if (isFullJoin(joinType_)) { + if (input_ && noMoreRightInput_) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(input_, nullptr)) { + output_->resize(outputSize_); + return std::move(output_); + } + while (true) { + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + addOutputRowForLeftJoin(input_, index_); + + ++index_; + if (index_ == input_->size()) { + // Ran out of rows on the left side. + input_ = nullptr; + return nullptr; + } + } + } + + if (noMoreInput_ && output_) { + output_->resize(outputSize_); + return std::move(output_); + } + + if (rightInput_ && noMoreInput_) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + + addOutputRowForRightJoin(rightInput_, rightIndex_); + + ++rightIndex_; + if (rightIndex_ == rightInput_->size()) { + // Ran out of rows on the right side. 
+ rightInput_ = nullptr; + return nullptr; + } + } + } + if (noMoreRightInput_ && output_) { output_->resize(outputSize_); return std::move(output_); @@ -833,7 +911,8 @@ RowVectorPtr MergeJoin::doGetOutput() { for (;;) { // Catch up input_ with rightInput_. while (compareResult < 0) { - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + if (isLeftJoin(joinType_) || isAntiJoin(joinType_) || + isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(input_, nullptr)) { @@ -860,7 +939,7 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up rightInput_ with input_. while (compareResult > 0) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(nullptr, rightInput_)) { @@ -928,7 +1007,12 @@ RowVectorPtr MergeJoin::doGetOutput() { } index_ = endIndex; - rightIndex_ = firstNonNull(rightInput_, rightKeys_, endRightIndex); + if (isFullJoin(joinType_)) { + rightIndex_ = endRightIndex; + } else { + rightIndex_ = firstNonNull(rightInput_, rightKeys_, endRightIndex); + } + if (rightIndex_ == rightInput_->size()) { // Ran out of rows on the right side. rightInput_ = nullptr; @@ -952,6 +1036,8 @@ RowVectorPtr MergeJoin::doGetOutput() { RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); + RowVectorPtr fullOuterOutput = nullptr; + BufferPtr indices = allocateIndices(numRows, pool()); auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; @@ -972,7 +1058,61 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; - if (!isRightJoin(joinType_)) { + if (isFullJoin(joinType_)) { + // For filtered rows, it is necessary to insert additional data + // to ensure the result set is complete. 
Specifically, we + // need to generate two records: one record containing the + // columns from the left table along with nulls for the + // right table, and another record containing the columns + // from the right table along with nulls for the left table. + // For instance, the current output is filtered based on the condition + // t > 1. + + // 1, 1 + // 2, 2 + // 3, 3 + + // In this scenario, we need to additionally insert a record 1, 1. + // Subsequently, we will set the values of the columns on the left to + // null and the values of the columns on the right to null as well. By + // doing so, we will obtain the final result set. + + // 1, null + // null, 1 + // 2, 2 + // 3, 3 + fullOuterOutput = BaseVector::create( + output->type(), output->size() + 1, pool()); + + for (auto i = 0; i < row + 1; i++) { + for (auto j = 0; j < output->type()->size(); j++) { + fullOuterOutput->childAt(j)->copy( + output->childAt(j).get(), i, i, 1); + } + } + + for (auto j = 0; j < output->type()->size(); j++) { + fullOuterOutput->childAt(j)->copy( + output->childAt(j).get(), row + 1, row, 1); + } + + for (auto i = row + 1; i < output->size(); i++) { + for (auto j = 0; j < output->type()->size(); j++) { + fullOuterOutput->childAt(j)->copy( + output->childAt(j).get(), i + 1, i, 1); + } + } + + for (auto& projection : leftProjections_) { + auto target = fullOuterOutput->childAt(projection.outputChannel); + target->setNull(row, true); + } + + for (auto& projection : rightProjections_) { + auto target = fullOuterOutput->childAt(projection.outputChannel); + target->setNull(row + 1, true); + } + } else if (!isRightJoin(joinType_)) { for (auto& projection : rightProjections_) { auto target = output->childAt(projection.outputChannel); target->setNull(row, true); @@ -1052,10 +1192,17 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (numPassed == numRows) { // All rows passed. 
+ if (fullOuterOutput) { + return fullOuterOutput; + } return output; } // Some, but not all rows passed. + if (fullOuterOutput) { + return wrap(numPassed, indices, fullOuterOutput); + } + return wrap(numPassed, indices, output); } @@ -1073,6 +1220,12 @@ bool MergeJoin::isFinished() { // complete. return noMoreInput_ && noMoreRightInput_ && rightInput_ == nullptr; } + + if (isFullJoin(joinType_)) { + return noMoreInput_ && input_ == nullptr && noMoreRightInput_ && + rightInput_ == nullptr; + } + return noMoreInput_ && input_ == nullptr; } diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 48f1e5700a2d..34102753b208 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -218,6 +218,37 @@ class MergeJoinTest : public HiveConnectorTestBase { // Test right join and left join with same result. auto expectedResult = AssertQueryBuilder(leftPlan).copyResults(pool_.get()); AssertQueryBuilder(rightPlan).assertResults(expectedResult); + + // Test FULL join. + planNodeIdGenerator = std::make_shared(); + auto fullPlan = PlanBuilder(planNodeIdGenerator) + .values(right) + .mergeJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values(left) + .project({"c1 as u_c1", "c0 as u_c0"}) + .planNode(), + "", + {"u_c0", "u_c1", "c1"}, + core::JoinType::kFull) + .planNode(); + + // Use very small output batch size. + assertQuery( + makeCursorParameters(fullPlan, 16), + "SELECT t.c0, t.c1, u.c1 FROM u FULL OUTER JOIN t ON t.c0 = u.c0"); + + // Use regular output batch size. + assertQuery( + makeCursorParameters(fullPlan, 1024), + "SELECT t.c0, t.c1, u.c1 FROM u FULL OUTER JOIN t ON t.c0 = u.c0"); + + // Use very large output batch size. 
+ assertQuery( + makeCursorParameters(fullPlan, 10'000), + "SELECT t.c0, t.c1, u.c1 FROM u FULL OUTER JOIN t ON t.c0 = u.c0"); } }; @@ -907,6 +938,131 @@ TEST_F(MergeJoinTest, antiJoinNoFilter) { "SELECT t0 FROM t WHERE NOT exists (select 1 from u where t0 = u0)"); } +TEST_F(MergeJoinTest, fullOuterJoin) { + auto left = makeRowVector( + {"t0"}, + {makeNullableFlatVector( + {1, 2, std::nullopt, 5, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 5, 6, 8, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Full outer join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "t0 > 2", + {"t0", "u0"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT * FROM t FULL OUTER JOIN u ON t.t0 = u.u0 AND t.t0 > 2"); +} + +TEST_F(MergeJoinTest, fullOuterJoinNoFilter) { + auto left = makeRowVector( + {"t0", "t1", "t2", "t3"}, + {makeNullableFlatVector( + {7854252584298216695, + 5874550437257860379, + 6694700278390749883, + 6952978413716179087, + 2785313305792069690, + 5306984336093303849, + 2249699434807719017, + std::nullopt, + std::nullopt, + std::nullopt, + 8814597374860168988}), + makeNullableFlatVector( + {1, 2, 3, 4, 5, 6, 7, std::nullopt, 8, 9, 10}), + makeNullableFlatVector( + {false, + true, + false, + false, + false, + true, + true, + false, + true, + false, + false}), + makeNullableFlatVector( + {58, 112, 125, 52, 69, 39, 73, 29, 101, std::nullopt, 51})}); + + auto right = makeRowVector( + {"u0", "u1", "u2", "u3"}, + {makeNullableFlatVector({std::nullopt}), + makeNullableFlatVector({11}), + makeNullableFlatVector({false}), + makeNullableFlatVector({77})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Full 
outer join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0", "t1", "t2", "t3"}, + {"u0", "u1", "u2", "u3"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "t1"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT t0, t1 FROM t FULL OUTER JOIN u ON t3 = u3 and t2 = u2 and t1 = u1 and t.t0 = u.u0"); +} + +TEST_F(MergeJoinTest, fullOuterJoinWithNullCompare) { + auto right = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({false, true}), + makeNullableFlatVector({std::nullopt, std::nullopt})}); + + auto left = makeRowVector( + {"t0", "t1"}, + {makeNullableFlatVector({false, false, std::nullopt}), + makeNullableFlatVector( + {std::nullopt, 1195665568, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Full outer join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0", "t1"}, + {"u0", "u1"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "", + {"t0", "t1", "u0", "u1"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT t0, t1, u0, u1 FROM t FULL OUTER JOIN u ON t.t0 = u.u0 and t1 = u1"); +} + TEST_F(MergeJoinTest, complexTypedFilter) { constexpr vector_size_t size{1000};