diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp
index 438d7cb87689..685eb1d5cf07 100644
--- a/velox/exec/MergeJoin.cpp
+++ b/velox/exec/MergeJoin.cpp
@@ -507,6 +507,14 @@ bool MergeJoin::prepareOutput(
 }
 
 bool MergeJoin::addToOutput() {
+  if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
+    return addToOutputForRightJoin();
+  } else {
+    return addToOutputForLeftJoin();
+  }
+}
+
+bool MergeJoin::addToOutputForLeftJoin() {
   size_t firstLeftBatch;
   vector_size_t leftStartIndex;
   if (leftMatch_->cursor) {
@@ -552,10 +560,8 @@ bool MergeJoin::addToOutput() {
         // one match on the other side, we could explore specialized algorithms
         // or data structures that short-circuit the join process once a match
         // is found.
-        if (isLeftSemiFilterJoin(joinType_) ||
-            isRightSemiFilterJoin(joinType_)) {
+        if (isLeftSemiFilterJoin(joinType_)) {
           // LeftSemiFilter produce each row from the left at most once.
-          // RightSemiFilter produce each row from the right at most once.
           rightEnd = rightStart + 1;
         }
 
@@ -587,6 +593,84 @@ bool MergeJoin::addToOutput() {
   return outputSize_ == outputBatchSize_;
 }
 
+bool MergeJoin::addToOutputForRightJoin() {
+  size_t firstRightBatch;
+  vector_size_t rightStartIndex;
+  if (rightMatch_->cursor) {
+    firstRightBatch = rightMatch_->cursor->batchIndex;
+    rightStartIndex = rightMatch_->cursor->index;
+  } else {
+    firstRightBatch = 0;
+    rightStartIndex = rightMatch_->startIndex;
+  }
+
+  size_t numRights = rightMatch_->inputs.size();
+  for (size_t r = firstRightBatch; r < numRights; ++r) {
+    auto right = rightMatch_->inputs[r];
+    auto rightStart = r == firstRightBatch ? rightStartIndex : 0;
+    auto rightEnd = r == numRights - 1 ? rightMatch_->endIndex : right->size();
+
+    for (auto i = rightStart; i < rightEnd; ++i) {
+      auto firstLeftBatch =
+          (r == firstRightBatch && i == rightStart && leftMatch_->cursor)
+          ? leftMatch_->cursor->batchIndex
+          : 0;
+
+      auto leftStartIndex =
+          (r == firstRightBatch && i == rightStart && leftMatch_->cursor)
+          ? leftMatch_->cursor->index
+          : leftMatch_->startIndex;
+
+      auto numLefts = leftMatch_->inputs.size();
+      for (size_t l = firstLeftBatch; l < numLefts; ++l) {
+        auto left = leftMatch_->inputs[l];
+        auto leftStart = l == firstLeftBatch ? leftStartIndex : 0;
+        auto leftEnd = l == numLefts - 1 ? leftMatch_->endIndex : left->size();
+
+        if (prepareOutput(left, right)) {
+          output_->resize(outputSize_);
+          leftMatch_->setCursor(l, leftStart);
+          rightMatch_->setCursor(r, i);
+          return true;
+        }
+
+        // TODO: Since semi joins only require determining if there is at least
+        // one match on the other side, we could explore specialized algorithms
+        // or data structures that short-circuit the join process once a match
+        // is found.
+        if (isRightSemiFilterJoin(joinType_)) {
+          // RightSemiFilter produces each row from the right at most once.
+          leftEnd = leftStart + 1;
+        }
+
+        for (auto j = leftStart; j < leftEnd; ++j) {
+          if (outputSize_ == outputBatchSize_) {
+            // If we run out of space in the current output_, we will need to
+            // produce a buffer and continue processing left later. In this
+            // case, we cannot leave left as a lazy vector, since we cannot have
+            // two dictionaries wrapping the same lazy vector.
+            loadColumns(currentLeft_, *operatorCtx_->execCtx());
+            rightMatch_->setCursor(r, i);
+            leftMatch_->setCursor(l, j);
+            return true;
+          }
+          addOutputRow(left, j, right, i);
+        }
+      }
+    }
+  }
+
+  leftMatch_.reset();
+  rightMatch_.reset();
+
+  // If the current key match finished but the right batch still has rows to
+  // process, load the lazy vectors of the left batch (see comment above).
+  if (rightInput_ && rightIndex_ != rightInput_->size()) {
+    loadColumns(currentLeft_, *operatorCtx_->execCtx());
+  }
+  return outputSize_ == outputBatchSize_;
+}
+
 namespace {
 vector_size_t firstNonNull(
     const RowVectorPtr& rowVector,
@@ -649,6 +733,7 @@ RowVectorPtr MergeJoin::getOutput() {
     if (output != nullptr && output->size() > 0) {
       if (filter_) {
         output = applyFilter(output);
+
         if (output != nullptr) {
           for (const auto [channel, _] : filterInputToOutputChannel_) {
             filterInput_->childAt(channel).reset();
@@ -689,7 +774,16 @@ RowVectorPtr MergeJoin::getOutput() {
           if (isFullJoin(joinType_)) {
             rightIndex_ = 0;
           } else {
-            rightIndex_ = firstNonNull(rightInput_, rightKeys_);
+            auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_);
+            if (isRightJoin(joinType_) && firstNonNullIndex > 0) {
+              // NOTE(review): prepareOutput()'s result is ignored and this
+              // loop never checks outputBatchSize_; confirm that is safe.
+              prepareOutput(nullptr, rightInput_);
+              for (vector_size_t i = 0; i < firstNonNullIndex; ++i) {
+                addOutputRowForRightJoin(rightInput_, i);
+              }
+            }
+            rightIndex_ = firstNonNullIndex;
             if (rightIndex_ == rightInput_->size()) {
               // Ran out of rows on the right side.
               rightInput_ = nullptr;
diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h
index 42222f83ae2e..3530316b90c7 100644
--- a/velox/exec/MergeJoin.h
+++ b/velox/exec/MergeJoin.h
@@ -180,14 +180,23 @@ class MergeJoin : public Operator {
   bool prepareOutput(const RowVectorPtr& left, const RowVectorPtr& right);
 
   // Appends a cartesian product of the current set of matching rows, leftMatch_
-  // x rightMatch_, to output_. Returns true if output_ is full. Sets
-  // leftMatchCursor_ and rightMatchCursor_ if output_ filled up before all the
-  // rows were added. Fills up output starting from leftMatchCursor_ and
-  // rightMatchCursor_ positions if these are set. Clears leftMatch_ and
-  // rightMatch_ if all rows were added. Updates leftMatchCursor_ and
-  // rightMatchCursor_ if output_ filled up before all rows were added.
+  // x rightMatch_ for left join and rightMatch_ x leftMatch_ for right join, to
+  // output_. Dispatches to addToOutputForRightJoin() for right and right semi
+  // filter joins and to addToOutputForLeftJoin() otherwise. Returns true if
+  // output_ is full. Resumes from the cursors stored in leftMatch_ and
+  // rightMatch_ if these are set. Sets both cursors if output_ filled up
+  // before all rows were added. Clears leftMatch_ and rightMatch_ if all rows
+  // were added.
   bool addToOutput();
 
+  // Appends leftMatch_ x rightMatch_ to output_. Used for every join type
+  // except right and right semi filter joins.
+  bool addToOutputForLeftJoin();
+
+  // Appends rightMatch_ x leftMatch_ to output_, with right rows driving the
+  // outer loop. Used for right and right semi filter joins.
+  bool addToOutputForRightJoin();
+
   // Adds one row of output by writing to the indices of the output
   // dictionaries. By default, this operator returns dictionaries wrapped around
   // the input columns from the left and right. If `isRightFlattened_`, the
diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp
index 34102753b208..72062c641679 100644
--- a/velox/exec/tests/MergeJoinTest.cpp
+++ b/velox/exec/tests/MergeJoinTest.cpp
@@ -508,6 +508,82 @@ TEST_F(MergeJoinTest, leftAndRightJoinFilter) {
   }
 }
 
+TEST_F(MergeJoinTest, rightJoinWithDuplicateMatch) {
+  // Join key 2 repeats on both sides, giving many-to-many matches.
+  auto left = makeRowVector(
+      {"a", "b"},
+      {
+          makeNullableFlatVector<int32_t>({1, 2, 2, 2, 3, 5, 6, std::nullopt}),
+          makeNullableFlatVector<double>(
+              {2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}),
+      });
+
+  auto right = makeRowVector(
+      {"c", "d"},
+      {
+          makeNullableFlatVector<int32_t>(
+              {0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}),
+          makeNullableFlatVector<double>(
+              {0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}),
+      });
+
+  createDuckDbTable("t", {left});
+  createDuckDbTable("u", {right});
+
+  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
+
+  auto rightPlan =
+      PlanBuilder(planNodeIdGenerator)
+          .values({left})
+          .mergeJoin(
+              {"a"},
+              {"c"},
+              PlanBuilder(planNodeIdGenerator).values({right}).planNode(),
+              "b < d",
+              {"a", "b", "c", "d"},
+              core::JoinType::kRight)
+          .planNode();
+  AssertQueryBuilder(rightPlan, duckDbQueryRunner_)
+      .assertResults("SELECT * from t RIGHT JOIN u ON a = c AND b < d");
+}
+
+TEST_F(MergeJoinTest, rightJoinFilterWithNull) {
+  auto left = makeRowVector(
+      {"a", "b"},
+      {
+          makeNullableFlatVector<int32_t>({std::nullopt, std::nullopt}),
+          makeNullableFlatVector<double>({std::nullopt, std::nullopt}),
+      });
+
+  auto right = makeRowVector(
+      {"c", "d"},
+      {
+          makeNullableFlatVector<int32_t>(
+              {std::nullopt, std::nullopt, std::nullopt}),
+          makeNullableFlatVector<double>(
+              {std::nullopt, std::nullopt, std::nullopt}),
+      });
+
+  createDuckDbTable("t", {left});
+  createDuckDbTable("u", {right});
+
+  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
+
+  auto rightPlan =
+      PlanBuilder(planNodeIdGenerator)
+          .values({left})
+          .mergeJoin(
+              {"a"},
+              {"c"},
+              PlanBuilder(planNodeIdGenerator).values({right}).planNode(),
+              "b < d",
+              {"a", "b", "c", "d"},
+              core::JoinType::kRight)
+          .planNode();
+  AssertQueryBuilder(rightPlan, duckDbQueryRunner_)
+      .assertResults("SELECT * from t RIGHT JOIN u ON a = c AND b < d");
+}
+
 // Verify that both left-side and right-side pipelines feeding the merge join
 // always run single-threaded.
 TEST_F(MergeJoinTest, numDrivers) {