Skip to content

Commit

Permalink
Fix semi join and anti join result mismatch issue
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Dec 9, 2024
1 parent 929affe commit 77d2d90
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 32 deletions.
116 changes: 86 additions & 30 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
} else if (
joinNode_->isAntiJoin() || joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right. Semi filter joins use the same tracker to record which rows have a
// match on the other side.
joinTracker_ = JoinTracker(outputBatchSize_, pool());
Expand Down Expand Up @@ -358,22 +362,24 @@ void MergeJoin::addOutputRow(
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

if (joinTracker_) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
} else {
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
}
}
}

// Anti join needs to track the left side rows that have no match on the
// right.
if (isAntiJoin(joinType_)) {
} else if (isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_)) {
// Anti join needs to track the left side rows that have no match on the
// right.
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
} else if (isRightSemiFilterJoin(joinType_)) {
VELOX_CHECK(joinTracker_);
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
}

++outputSize_;
Expand All @@ -390,7 +396,7 @@ bool MergeJoin::prepareOutput(
return true;
}

if (isRightJoin(joinType_) && right != currentRight_) {
if (right != currentRight_) {
return true;
}

Expand Down Expand Up @@ -560,7 +566,7 @@ bool MergeJoin::addToOutputForLeftJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_)) {
if (isLeftSemiFilterJoin(joinType_) && !filter_) {
// LeftSemiFilter produces each row from the left at most once.
rightEnd = rightStart + 1;
}
Expand Down Expand Up @@ -638,7 +644,7 @@ bool MergeJoin::addToOutputForRightJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isRightSemiFilterJoin(joinType_)) {
if (isRightSemiFilterJoin(joinType_) && !filter_) {
// RightSemiFilter produces each row from the right at most once.
leftEnd = leftStart + 1;
}
Expand Down Expand Up @@ -693,6 +699,37 @@ vector_size_t firstNonNull(
}
} // namespace

// Builds the semi join result from 'output' by keeping only the rows that
// 'joinTracker_' reports as matches. Returns nullptr when the tracker has no
// matching rows among the first output->size() rows; otherwise returns
// 'output' wrapped in a dictionary that selects the kept rows.
RowVectorPtr MergeJoin::filterOutputForSemiJoin(const RowVectorPtr& output) {
  const auto rowCount = output->size();
  const auto& trackedRows = joinTracker_->matchingRows(rowCount);
  const auto trackedCount = trackedRows.countSelected();
  if (trackedCount == 0) {
    return nullptr;
  }

  // Sized for the case where every tracked row is kept.
  BufferPtr indices = allocateIndices(trackedCount, pool());
  auto* rawIndices = indices->asMutable<vector_size_t>();
  size_t numKept{0};

  const bool isSemiJoin =
      isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_);

  // A miss never contributes a row to this function's result.
  auto ignoreMiss = [&](auto /*row*/) {};

  // Remember the output row that the tracker hands back as a match.
  auto keepMatch = [&](auto row) {
    if (isSemiJoin) {
      rawIndices[numKept++] = row;
    }
  };

  for (auto row = 0; row < rowCount; ++row) {
    if (!trackedRows.isValid(row)) {
      continue;
    }
    // Each tracked row is reported as having passed; the tracker decides
    // whether 'keepMatch' fires for it.
    // NOTE(review): the tracker appears to invoke the match callback only for
    // the first passing output row of each tracked input row - confirm
    // against JoinTracker::processFilterResult.
    joinTracker_->processFilterResult(row, true, ignoreMiss, keepMatch);
  }

  // 'numKept' may be smaller than 'trackedCount'; only the first 'numKept'
  // indices are meaningful.
  return wrap(numKept, indices, output);
}

RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) {
const auto numRows = output->size();
const auto& filterRows = joinTracker_->matchingRows(numRows);
Expand Down Expand Up @@ -745,7 +782,16 @@ RowVectorPtr MergeJoin::getOutput() {
continue;
} else if (isAntiJoin(joinType_)) {
output = filterOutputForAntiJoin(output);
if (output) {
if (output != nullptr && output->size() > 0) {
return output;
}

// No rows survived the filter for anti join. Get more rows.
continue;
} else if (
isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
output = filterOutputForSemiJoin(output);
if (output != nullptr && output->size() > 0) {
return output;
}

Expand Down Expand Up @@ -806,7 +852,9 @@ RowVectorPtr MergeJoin::doGetOutput() {

// Not all rows from the last match fit in the output. Continue producing
// results from the current match.
if (addToOutput()) {
addToOutput();
if (outputSize_ > 0) {
output_->resize(outputSize_);
return std::move(output_);
}
}
Expand Down Expand Up @@ -865,7 +913,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
VELOX_CHECK(leftMatch_->complete);
VELOX_CHECK(rightMatch_ && rightMatch_->complete);

if (addToOutput()) {
addToOutput();
if (outputSize_ > 0) {
output_->resize(outputSize_);
return std::move(output_);
}
}
Expand Down Expand Up @@ -1104,7 +1154,9 @@ RowVectorPtr MergeJoin::doGetOutput() {
rightInput_ = nullptr;
}

if (addToOutput()) {
addToOutput();
if (outputSize_ > 0) {
output_->resize(outputSize_);
return std::move(output_);
}

Expand Down Expand Up @@ -1141,7 +1193,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
if (!isAntiJoin(joinType_)) {
if (!isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
Expand Down Expand Up @@ -1212,18 +1265,21 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
};

auto onMatch = [&](auto row) {
if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
rawIndices[numPassed++] = row;
}
};

for (auto i = 0; i < numRows; ++i) {
if (filterRows.isValid(i)) {
const bool passed = !decodedFilterResult_.isNullAt(i) &&
decodedFilterResult_.valueAt<bool>(i);

joinTracker_->processFilterResult(i, passed, onMiss);
joinTracker_->processFilterResult(i, passed, onMiss, onMatch);

if (isAntiJoin(joinType_)) {
if (!passed) {
rawIndices[numPassed++] = i;
}
} else {
if (!isAntiJoin(joinType_) && !isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
if (passed) {
rawIndices[numPassed++] = i;
}
Expand All @@ -1237,19 +1293,19 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

// Every time we start a new left key match, `processFilterResult()` will
// check if at least one row from the previous match passed the filter. If
// none did, it calls onMiss to add a record with null right projections to
// the output.
// none did, it calls onMiss to add a record with null right projections
// to the output.
//
// Before we leave the current buffer, since we may not have seen the next
// left key match yet, the last key match may still be pending to produce a
// row (because `processFilterResult()` was not called yet).
// left key match yet, the last key match may still be pending to produce
// a row (because `processFilterResult()` was not called yet).
//
// To handle this, we need to call `noMoreFilterResults()` unless the
// same current left key match may continue in the next buffer. So there are
// two cases to check:
// same current left key match may continue in the next buffer. So there
// are two cases to check:
//
// 1. If leftMatch_ is nullopt, there for sure the next buffer will contain
// a different key match.
// 1. If leftMatch_ is nullopt, there for sure the next buffer will
// contain a different key match.
//
// 2. leftMatch_ may not be nullopt, but may be related to a different
// (subsequent) left key. So we check if the last row in the batch has the
Expand Down
16 changes: 14 additions & 2 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ class MergeJoin : public Operator {
/// rows from the left side that have a match on the right.
RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output);

RowVectorPtr filterOutputForSemiJoin(const RowVectorPtr& output);

/// As we populate the results of the join, we track whether a given
/// output row is a result of a match between left and right sides or a miss.
/// We use JoinTracker::addMatch and addMiss methods for that.
Expand Down Expand Up @@ -351,11 +353,12 @@ class MergeJoin : public Operator {
/// rows that correspond to a single left-side row. Use
/// 'noMoreFilterResults' to make sure 'onMiss' is called for the last
/// left-side row.
template <typename TOnMiss>
template <typename TOnMiss, typename TOnMatch>
void processFilterResult(
vector_size_t outputIndex,
bool passed,
TOnMiss onMiss) {
TOnMiss onMiss,
TOnMatch onMatch) {
auto rowNumber = rawLeftRowNumbers_[outputIndex];
if (currentLeftRowNumber_ != rowNumber) {
if (currentRow_ != -1 && !currentRowPassed_) {
Expand All @@ -364,12 +367,18 @@ class MergeJoin : public Operator {
currentRow_ = outputIndex;
currentLeftRowNumber_ = rowNumber;
currentRowPassed_ = false;
firstMatched_ = false;
} else {
currentRow_ = outputIndex;
}

if (passed) {
currentRowPassed_ = true;

if (!firstMatched_) {
onMatch(outputIndex);
firstMatched_ = true;
}
}
}

Expand All @@ -391,6 +400,7 @@ class MergeJoin : public Operator {

currentRow_ = -1;
currentRowPassed_ = false;
firstMatched_ = false;
}

private:
Expand Down Expand Up @@ -425,6 +435,8 @@ class MergeJoin : public Operator {
// True if at least one row in a block of output rows corresponding to a
// single left-side row identified by 'currentLeftRowNumber_' passed the
// filter.
bool currentRowPassed_{false};

bool firstMatched_{false};
};

/// Used to record both left and right join.
Expand Down
Loading

0 comments on commit 77d2d90

Please sign in to comment.