Skip to content

Commit

Permalink
Fix the full outer join result mismatch issue with multi duplicated r…
Browse files Browse the repository at this point in the history
…ecords
  • Loading branch information
JkSelf committed Dec 31, 2024
1 parent 006579d commit 8032743
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 103 deletions.
190 changes: 92 additions & 98 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "velox/exec/Task.h"
#include "velox/expression/FieldReference.h"

#include <iostream>

namespace facebook::velox::exec {

MergeJoin::MergeJoin(
Expand Down Expand Up @@ -96,7 +98,7 @@ void MergeJoin::initialize() {
}
} else if (
joinNode_->isAntiJoin() || joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin()) {
joinNode_->isRightSemiFilterJoin() || joinNode_->isFullJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right.
joinTracker_ = JoinTracker(outputBatchSize_, pool());
Expand Down Expand Up @@ -342,7 +344,8 @@ void MergeJoin::addOutputRow(
const RowVectorPtr& left,
vector_size_t leftIndex,
const RowVectorPtr& right,
vector_size_t rightIndex) {
vector_size_t rightIndex,
bool isRightJoinForFullOuter) {
// All left side projections share the same dictionary indices (leftIndices_).
rawLeftIndices_[outputSize_] = leftIndex;

Expand All @@ -362,24 +365,33 @@ void MergeJoin::addOutputRow(
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

if (joinTracker_) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
joinTracker_->addMatch(
right, rightIndex, outputSize_, isRightJoinForFullOuter);
} else {
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
joinTracker_->addMatch(
left, leftIndex, outputSize_, isRightJoinForFullOuter);
}
}
} else if (isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_)) {
} else if (
isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_) ||
(isFullJoin(joinType_) && !isRightJoinForFullOuter)) {
// Anti join needs to track the left side rows that have no match on the
// right.
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
} else if (isRightSemiFilterJoin(joinType_)) {
joinTracker_->addMatch(
left, leftIndex, outputSize_, isRightJoinForFullOuter);
} else if (
isRightSemiFilterJoin(joinType_) ||
(isFullJoin(joinType_) && isRightJoinForFullOuter)) {
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
joinTracker_->addMatch(
right, rightIndex, outputSize_, isRightJoinForFullOuter);
}

++outputSize_;
Expand All @@ -396,14 +408,14 @@ bool MergeJoin::prepareOutput(
return true;
}

if (right != currentRight_) {
return true;
}

// If there is a new right, we need to flatten the dictionary.
if (!isRightFlattened_ && right && currentRight_ != right) {
flattenRightProjections();
}

if (right != currentRight_) {
return true;
}
return false;
}

Expand Down Expand Up @@ -515,6 +527,39 @@ bool MergeJoin::prepareOutput(
bool MergeJoin::addToOutput() {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
return addToOutputForRightJoin();
} else if (isFullJoin(joinType_) && filter_) {
if (!leftForRightJoinMatch_) {
leftForRightJoinMatch_ = leftMatch_;
rightForRightJoinMatch_ = rightMatch_;
}

if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) {
auto left = addToOutputForLeftJoin();
if (!leftMatch_) {
leftJoinForFullFinished_ = true;
}
if (left) {
if (!leftMatch_) {
leftMatch_ = leftForRightJoinMatch_;
rightMatch_ = rightForRightJoinMatch_;
}

return true;
}
}

if (!leftMatch_ && !rightJoinForFullFinished_) {
leftMatch_ = leftForRightJoinMatch_;
rightMatch_ = rightForRightJoinMatch_;
rightJoinForFullFinished_ = true;
}

auto right = addToOutputForRightJoin();

leftForRightJoinMatch_ = leftMatch_;
rightForRightJoinMatch_ = rightMatch_;

return right;
} else {
return addToOutputForLeftJoin();
}
Expand Down Expand Up @@ -660,7 +705,12 @@ bool MergeJoin::addToOutputForRightJoin() {
leftMatch_->setCursor(l, j);
return true;
}
addOutputRow(left, j, right, i);

if (isFullJoin(joinType_)) {
addOutputRow(left, j, right, i, true);
} else {
addOutputRow(left, j, right, i);
}
}
}
}
Expand Down Expand Up @@ -713,7 +763,7 @@ RowVectorPtr MergeJoin::filterOutputForSemiJoin(const RowVectorPtr& output) {

// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {};
auto onMiss = [&](auto row, bool flag) {};

auto onMatch = [&](auto row) {
if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
Expand Down Expand Up @@ -817,22 +867,19 @@ RowVectorPtr MergeJoin::getOutput() {
}

if (rightInput_) {
if (isFullJoin(joinType_)) {
rightIndex_ = 0;
} else {
auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_);
if (isRightJoin(joinType_) && firstNonNullIndex > 0) {
prepareOutput(nullptr, rightInput_);
for (auto i = 0; i < firstNonNullIndex; ++i) {
addOutputRowForRightJoin(rightInput_, i);
}
}
rightIndex_ = firstNonNullIndex;
if (finishedRightBatch()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_);
if ((isRightJoin(joinType_) || isFullJoin(joinType_)) &&
firstNonNullIndex > 0) {
prepareOutput(nullptr, rightInput_);
for (auto i = 0; i < firstNonNullIndex; ++i) {
addOutputRowForRightJoin(rightInput_, i);
}
}
rightIndex_ = firstNonNullIndex;
if (finishedRightBatch()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
}
} else {
noMoreRightInput_ = true;
}
Expand Down Expand Up @@ -1054,7 +1101,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
if (prepareOutput(input_, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}
Expand All @@ -1080,7 +1127,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
if (prepareOutput(input_, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}
Expand Down Expand Up @@ -1129,6 +1176,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
endRightIndex < rightInput_->size(),
std::nullopt};

leftJoinForFullFinished_ = false;
rightJoinForFullFinished_ = false;
if (!leftMatch_->complete || !rightMatch_->complete) {
if (!leftMatch_->complete) {
// Need to continue looking for the end of match.
Expand All @@ -1143,6 +1192,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

index_ = endIndex;

if (isFullJoin(joinType_)) {
rightIndex_ = endRightIndex;
} else {
Expand Down Expand Up @@ -1174,8 +1224,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const auto numRows = output->size();

RowVectorPtr fullOuterOutput = nullptr;

BufferPtr indices = allocateIndices(numRows, pool());
auto rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;
Expand All @@ -1192,70 +1240,23 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
auto onMiss = [&](auto row, bool flag) {
if (!isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
// For filtered rows, it is necessary to insert additional data
// to ensure the result set is complete. Specifically, we
// need to generate two records: one record containing the
// columns from the left table along with nulls for the
// right table, and another record containing the columns
// from the right table along with nulls for the left table.
// For instance, the current output is filtered based on the condition
// t > 1.

// 1, 1
// 2, 2
// 3, 3

// In this scenario, we need to additionally insert a record 1, 1.
// Subsequently, we will set the values of the columns on the left to
// null and the values of the columns on the right to null as well. By
// doing so, we will obtain the final result set.

// 1, null
// null, 1
// 2, 2
// 3, 3
fullOuterOutput = BaseVector::create<RowVector>(
output->type(), output->size() + 1, pool());

for (auto i = 0; i < row + 1; i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i, i, 1);
if (!isRightJoin(joinType_)) {
if (isFullJoin(joinType_) && flag) {
for (auto& projection : leftProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}

for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), row + 1, row, 1);
}

for (auto i = row + 1; i < output->size(); i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i + 1, i, 1);
} else {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
}

for (auto& projection : leftProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row, true);
}

for (auto& projection : rightProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
}
} else {
for (auto& projection : leftProjections_) {
auto target = output->childAt(projection.outputChannel);
Expand All @@ -1280,7 +1281,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (!isAntiJoin(joinType_) && !isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
if (passed) {
if (passed && !joinTracker_->isRightJoinForFullOuter(i)) {
rawIndices[numPassed++] = i;
}
}
Expand Down Expand Up @@ -1338,17 +1339,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (numPassed == numRows) {
// All rows passed.
if (fullOuterOutput) {
return fullOuterOutput;
}
return output;
}

// Some, but not all rows passed.
if (fullOuterOutput) {
return wrap(numPassed, indices, fullOuterOutput);
}

return wrap(numPassed, indices, output);
}

Expand Down
Loading

0 comments on commit 8032743

Please sign in to comment.