From 8032743671549cdcbb9d5080083cded286509556 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Tue, 31 Dec 2024 17:48:16 +0800 Subject: [PATCH] Fix the full outer join result mismatch issue with multi duplicated records --- velox/exec/MergeJoin.cpp | 190 ++++++++++++++--------------- velox/exec/MergeJoin.h | 31 ++++- velox/exec/tests/MergeJoinTest.cpp | 39 ++++++ 3 files changed, 157 insertions(+), 103 deletions(-) diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index c1c5a120bce5..3d59b2229dbe 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -18,6 +18,8 @@ #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" +#include + namespace facebook::velox::exec { MergeJoin::MergeJoin( @@ -96,7 +98,7 @@ void MergeJoin::initialize() { } } else if ( joinNode_->isAntiJoin() || joinNode_->isLeftSemiFilterJoin() || - joinNode_->isRightSemiFilterJoin()) { + joinNode_->isRightSemiFilterJoin() || joinNode_->isFullJoin()) { // Anti join needs to track the left side rows that have no match on the // right. joinTracker_ = JoinTracker(outputBatchSize_, pool()); @@ -342,7 +344,8 @@ void MergeJoin::addOutputRow( const RowVectorPtr& left, vector_size_t leftIndex, const RowVectorPtr& right, - vector_size_t rightIndex) { + vector_size_t rightIndex, + bool isRightJoinForFullOuter) { // All left side projections share the same dictionary indices (leftIndices_). rawLeftIndices_[outputSize_] = leftIndex; @@ -362,24 +365,33 @@ void MergeJoin::addOutputRow( copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); if (joinTracker_) { - if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && isRightJoinForFullOuter)) { // Record right-side row with a match on the left-side. 
- joinTracker_->addMatch(right, rightIndex, outputSize_); + joinTracker_->addMatch( + right, rightIndex, outputSize_, isRightJoinForFullOuter); } else { // Record left-side row with a match on the right-side. - joinTracker_->addMatch(left, leftIndex, outputSize_); + joinTracker_->addMatch( + left, leftIndex, outputSize_, isRightJoinForFullOuter); } } - } else if (isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_)) { + } else if ( + isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && !isRightJoinForFullOuter)) { // Anti join needs to track the left side rows that have no match on the // right. VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. - joinTracker_->addMatch(left, leftIndex, outputSize_); - } else if (isRightSemiFilterJoin(joinType_)) { + joinTracker_->addMatch( + left, leftIndex, outputSize_, isRightJoinForFullOuter); + } else if ( + isRightSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && isRightJoinForFullOuter)) { VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. - joinTracker_->addMatch(right, rightIndex, outputSize_); + joinTracker_->addMatch( + right, rightIndex, outputSize_, isRightJoinForFullOuter); } ++outputSize_; @@ -396,14 +408,14 @@ bool MergeJoin::prepareOutput( return true; } - if (right != currentRight_) { - return true; - } - // If there is a new right, we need to flatten the dictionary. 
if (!isRightFlattened_ && right && currentRight_ != right) { flattenRightProjections(); } + + if (right != currentRight_) { + return true; + } return false; } @@ -515,6 +527,39 @@ bool MergeJoin::prepareOutput( bool MergeJoin::addToOutput() { if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { return addToOutputForRightJoin(); + } else if (isFullJoin(joinType_) && filter_) { + if (!leftForRightJoinMatch_) { + leftForRightJoinMatch_ = leftMatch_; + rightForRightJoinMatch_ = rightMatch_; + } + + if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) { + auto left = addToOutputForLeftJoin(); + if (!leftMatch_) { + leftJoinForFullFinished_ = true; + } + if (left) { + if (!leftMatch_) { + leftMatch_ = leftForRightJoinMatch_; + rightMatch_ = rightForRightJoinMatch_; + } + + return true; + } + } + + if (!leftMatch_ && !rightJoinForFullFinished_) { + leftMatch_ = leftForRightJoinMatch_; + rightMatch_ = rightForRightJoinMatch_; + rightJoinForFullFinished_ = true; + } + + auto right = addToOutputForRightJoin(); + + leftForRightJoinMatch_ = leftMatch_; + rightForRightJoinMatch_ = rightMatch_; + + return right; } else { return addToOutputForLeftJoin(); } @@ -660,7 +705,12 @@ bool MergeJoin::addToOutputForRightJoin() { leftMatch_->setCursor(l, j); return true; } - addOutputRow(left, j, right, i); + + if (isFullJoin(joinType_)) { + addOutputRow(left, j, right, i, true); + } else { + addOutputRow(left, j, right, i); + } } } } @@ -713,7 +763,7 @@ RowVectorPtr MergeJoin::filterOutputForSemiJoin(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. 
- auto onMiss = [&](auto row) {}; + auto onMiss = [&](auto row, bool flag) {}; auto onMatch = [&](auto row) { if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { @@ -817,22 +867,19 @@ RowVectorPtr MergeJoin::getOutput() { } if (rightInput_) { - if (isFullJoin(joinType_)) { - rightIndex_ = 0; - } else { - auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_); - if (isRightJoin(joinType_) && firstNonNullIndex > 0) { - prepareOutput(nullptr, rightInput_); - for (auto i = 0; i < firstNonNullIndex; ++i) { - addOutputRowForRightJoin(rightInput_, i); - } - } - rightIndex_ = firstNonNullIndex; - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; + auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_); + if ((isRightJoin(joinType_) || isFullJoin(joinType_)) && + firstNonNullIndex > 0) { + prepareOutput(nullptr, rightInput_); + for (auto i = 0; i < firstNonNullIndex; ++i) { + addOutputRowForRightJoin(rightInput_, i); } } + rightIndex_ = firstNonNullIndex; + if (finishedRightBatch()) { + // Ran out of rows on the right side. + rightInput_ = nullptr; + } } else { noMoreRightInput_ = true; } @@ -1054,7 +1101,7 @@ RowVectorPtr MergeJoin::doGetOutput() { isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. - if (prepareOutput(input_, nullptr)) { + if (prepareOutput(input_, rightInput_)) { output_->resize(outputSize_); return std::move(output_); } @@ -1080,7 +1127,7 @@ RowVectorPtr MergeJoin::doGetOutput() { if (isRightJoin(joinType_) || isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. 
- if (prepareOutput(nullptr, rightInput_)) { + if (prepareOutput(input_, rightInput_)) { output_->resize(outputSize_); return std::move(output_); } @@ -1129,6 +1176,8 @@ RowVectorPtr MergeJoin::doGetOutput() { endRightIndex < rightInput_->size(), std::nullopt}; + leftJoinForFullFinished_ = false; + rightJoinForFullFinished_ = false; if (!leftMatch_->complete || !rightMatch_->complete) { if (!leftMatch_->complete) { // Need to continue looking for the end of match. @@ -1143,6 +1192,7 @@ RowVectorPtr MergeJoin::doGetOutput() { } index_ = endIndex; + if (isFullJoin(joinType_)) { rightIndex_ = endRightIndex; } else { @@ -1174,8 +1224,6 @@ RowVectorPtr MergeJoin::doGetOutput() { RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); - RowVectorPtr fullOuterOutput = nullptr; - BufferPtr indices = allocateIndices(numRows, pool()); auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; @@ -1192,70 +1240,23 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. - auto onMiss = [&](auto row) { + auto onMiss = [&](auto row, bool flag) { if (!isLeftSemiFilterJoin(joinType_) && !isRightSemiFilterJoin(joinType_)) { rawIndices[numPassed++] = row; - if (isFullJoin(joinType_)) { - // For filtered rows, it is necessary to insert additional data - // to ensure the result set is complete. Specifically, we - // need to generate two records: one record containing the - // columns from the left table along with nulls for the - // right table, and another record containing the columns - // from the right table along with nulls for the left table. - // For instance, the current output is filtered based on the condition - // t > 1. - - // 1, 1 - // 2, 2 - // 3, 3 - - // In this scenario, we need to additionally insert a record 1, 1. 
- // Subsequently, we will set the values of the columns on the left to - // null and the values of the columns on the right to null as well. By - // doing so, we will obtain the final result set. - - // 1, null - // null, 1 - // 2, 2 - // 3, 3 - fullOuterOutput = BaseVector::create( - output->type(), output->size() + 1, pool()); - - for (auto i = 0; i < row + 1; i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i, i, 1); + if (!isRightJoin(joinType_)) { + if (isFullJoin(joinType_) && flag) { + for (auto& projection : leftProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); } - } - - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), row + 1, row, 1); - } - - for (auto i = row + 1; i < output->size(); i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i + 1, i, 1); + } else { + for (auto& projection : rightProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); } } - - for (auto& projection : leftProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row, true); - } - - for (auto& projection : rightProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row + 1, true); - } - } else if (!isRightJoin(joinType_)) { - for (auto& projection : rightProjections_) { - auto target = output->childAt(projection.outputChannel); - target->setNull(row, true); - } } else { for (auto& projection : leftProjections_) { auto target = output->childAt(projection.outputChannel); @@ -1280,7 +1281,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!isAntiJoin(joinType_) && !isLeftSemiFilterJoin(joinType_) && !isRightSemiFilterJoin(joinType_)) { - if (passed) { + 
if (passed && !joinTracker_->isRightJoinForFullOuter(i)) { rawIndices[numPassed++] = i; } } @@ -1338,17 +1339,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (numPassed == numRows) { // All rows passed. - if (fullOuterOutput) { - return fullOuterOutput; - } return output; } // Some, but not all rows passed. - if (fullOuterOutput) { - return wrap(numPassed, indices, fullOuterOutput); - } - return wrap(numPassed, indices, output); } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index ba9269448882..d62768efb7c8 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -207,7 +207,8 @@ class MergeJoin : public Operator { const RowVectorPtr& left, vector_size_t leftIndex, const RowVectorPtr& right, - vector_size_t rightIndex); + vector_size_t rightIndex, + bool isRightJoinForFullOuter = false); // If the right side projected columns in the current output vector happen to // span more than one vector from the right side, they cannot be simply @@ -299,6 +300,9 @@ class MergeJoin : public Operator { : matchingRows_{numRows, false} { leftRowNumbers_ = AlignedBuffer::allocate(numRows, pool); rawLeftRowNumbers_ = leftRowNumbers_->asMutable(); + + rightJoinRows_ = AlignedBuffer::allocate(numRows, pool); + rawRightJoinRows_ = rightJoinRows_->asMutable(); } /// Records a row of output that corresponds to a match between a left-side @@ -309,7 +313,8 @@ class MergeJoin : public Operator { void addMatch( const VectorPtr& left, vector_size_t leftIndex, - vector_size_t outputIndex) { + vector_size_t outputIndex, + bool rightJoinForFullOuter = false) { matchingRows_.setValid(outputIndex, true); if (lastVector_ != left || lastIndex_ != leftIndex) { @@ -320,6 +325,7 @@ class MergeJoin : public Operator { } rawLeftRowNumbers_[outputIndex] = lastLeftRowNumber_; + rawRightJoinRows_[outputIndex] = rightJoinForFullOuter; } /// Returns a subset of "match" rows in [0, numRows) range that were @@ -362,7 +368,7 @@ class MergeJoin : public 
Operator { auto rowNumber = rawLeftRowNumbers_[outputIndex]; if (currentLeftRowNumber_ != rowNumber) { if (currentRow_ != -1 && !currentRowPassed_) { - onMiss(currentRow_); + onMiss(currentRow_, rawRightJoinRows_[currentRow_]); } currentRow_ = outputIndex; currentLeftRowNumber_ = rowNumber; @@ -394,8 +400,8 @@ class MergeJoin : public Operator { /// filter failed for all matches of that row. template void noMoreFilterResults(TOnMiss onMiss) { - if (!currentRowPassed_) { - onMiss(currentRow_); + if (!currentRowPassed_ && currentRow_ >= 0) { + onMiss(currentRow_, rawRightJoinRows_[currentRow_]); } currentRow_ = -1; @@ -403,6 +409,10 @@ class MergeJoin : public Operator { firstMatched_ = false; } + bool isRightJoinForFullOuter(vector_size_t row) { + return rawRightJoinRows_[row]; + } + private: // A subset of output rows where left side matched right side on the join // keys. Used in filter evaluation. @@ -422,6 +432,9 @@ class MergeJoin : public Operator { BufferPtr leftRowNumbers_; vector_size_t* rawLeftRowNumbers_; + BufferPtr rightJoinRows_; + bool* rawRightJoinRows_; + // Synthetic number assigned to the last added "match" row or zero if no row // has been added yet. vector_size_t lastLeftRowNumber_{0}; @@ -534,5 +547,13 @@ class MergeJoin : public Operator { // True if all the right side data has been received. 
bool noMoreRightInput_{false}; + + bool leftJoinForFullFinished_{false}; + + bool rightJoinForFullFinished_{false}; + + std::optional leftForRightJoinMatch_; + + std::optional rightForRightJoinMatch_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index e1f8ff206f0d..512252e07a27 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -1301,6 +1301,45 @@ TEST_F(MergeJoinTest, fullOuterJoin) { "SELECT * FROM t FULL OUTER JOIN u ON t.t0 = u.u0 AND t.t0 > 2"); } +TEST_F(MergeJoinTest, fullOuterJoinWithDuplicateMatch) { + // Join keys repeat on both sides; a left row may match several right rows. + auto left = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector({1, 2, 2, 2, 3, 5, 6, std::nullopt}), + makeNullableFlatVector( + {2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}), + }); + + auto right = makeRowVector( + {"c", "d"}, + { + makeNullableFlatVector( + {0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}), + makeNullableFlatVector( + {0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}), + }); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"a"}, + {"c"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "b < d", + {"a", "b", "c", "d"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(rightPlan, duckDbQueryRunner_) + .assertResults("SELECT * from t FULL OUTER JOIN u ON a = c AND b < d"); +} + TEST_F(MergeJoinTest, fullOuterJoinNoFilter) { auto left = makeRowVector( {"t0", "t1", "t2", "t3"},