Skip to content

Commit

Permalink
Support full outer join in smj (#10247)
Browse files Browse the repository at this point in the history
Summary:
Follow-up to #10148. A full outer join produces the union of the left join and right join results.

Pull Request resolved: #10247

Reviewed By: pedroerp

Differential Revision: D62760535

Pulled By: kevinwilfong

fbshipit-source-id: eb5a656872581986902a1f28ee4450492d03bbe4
  • Loading branch information
JkSelf authored and facebook-github-bot committed Sep 17, 2024
1 parent af2513b commit 794b8b6
Show file tree
Hide file tree
Showing 3 changed files with 328 additions and 18 deletions.
1 change: 1 addition & 0 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ bool MergeJoinNode::isSupported(core::JoinType joinType) {
case core::JoinType::kLeftSemiFilter:
case core::JoinType::kRightSemiFilter:
case core::JoinType::kAnti:
case core::JoinType::kFull:
return true;

default:
Expand Down
189 changes: 171 additions & 18 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
Expand Down Expand Up @@ -207,10 +207,23 @@ int32_t MergeJoin::compare(
const RowVectorPtr& otherBatch,
vector_size_t otherIndex) {
for (auto i = 0; i < keys.size(); ++i) {
CompareFlags compareFlags = {
.equalsOnly = true,
.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
auto compare = batch->childAt(keys[i])->compare(
otherBatch->childAt(otherKeys[i]).get(), index, otherIndex);
if (compare != 0) {
return compare;
otherBatch->childAt(otherKeys[i]).get(),
index,
otherIndex,
compareFlags);

// With kNullAsIndeterminate, comparing null with anything returns
// std::nullopt.
if (!compare.has_value()) {
// The SQL semantics of Presto and Spark will always return false if
// comparing a NULL value with any other value.
return -1;
} else if (compare.value() != 0) {
return compare.value();
}
}

Expand Down Expand Up @@ -272,7 +285,8 @@ void copyRow(
void MergeJoin::addOutputRowForLeftJoin(
const RowVectorPtr& left,
vector_size_t leftIndex) {
VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_));
VELOX_USER_CHECK(
isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_));
rawLeftIndices_[outputSize_] = leftIndex;

for (const auto& projection : rightProjections_) {
Expand All @@ -291,7 +305,7 @@ void MergeJoin::addOutputRowForLeftJoin(
void MergeJoin::addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex) {
VELOX_USER_CHECK(isRightJoin(joinType_));
VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_));
rawRightIndices_[outputSize_] = rightIndex;

for (const auto& projection : leftProjections_) {
Expand Down Expand Up @@ -672,10 +686,14 @@ RowVectorPtr MergeJoin::getOutput() {
}

if (rightInput_) {
rightIndex_ = firstNonNull(rightInput_, rightKeys_);
if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
if (isFullJoin(joinType_)) {
rightIndex_ = 0;
} else {
rightIndex_ = firstNonNull(rightInput_, rightKeys_);
if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
}
}
} else {
noMoreRightInput_ = true;
Expand Down Expand Up @@ -731,10 +749,14 @@ RowVectorPtr MergeJoin::doGetOutput() {
return nullptr;
}
if (rightMatch_->inputs.back() == rightInput_) {
rightIndex_ =
firstNonNull(rightInput_, rightKeys_, rightMatch_->endIndex);
if (rightIndex_ == rightInput_->size()) {
rightInput_ = nullptr;
if (isFullJoin(joinType_)) {
rightIndex_ = rightMatch_->endIndex;
} else {
rightIndex_ =
firstNonNull(rightInput_, rightKeys_, rightMatch_->endIndex);
if (rightIndex_ == rightInput_->size()) {
rightInput_ = nullptr;
}
}
}
} else if (noMoreRightInput_) {
Expand Down Expand Up @@ -809,6 +831,62 @@ RowVectorPtr MergeJoin::doGetOutput() {
}
}

if (noMoreRightInput_ && output_) {
output_->resize(outputSize_);
return std::move(output_);
}
} else if (isFullJoin(joinType_)) {
if (input_ && noMoreRightInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
output_->resize(outputSize_);
return std::move(output_);
}
while (true) {
if (outputSize_ == outputBatchSize_) {
return std::move(output_);
}
addOutputRowForLeftJoin(input_, index_);

++index_;
if (index_ == input_->size()) {
// Ran out of rows on the left side.
input_ = nullptr;
return nullptr;
}
}
}

if (noMoreInput_ && output_) {
output_->resize(outputSize_);
return std::move(output_);
}

if (rightInput_ && noMoreInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
output_->resize(outputSize_);
return std::move(output_);
}

while (true) {
if (outputSize_ == outputBatchSize_) {
return std::move(output_);
}

addOutputRowForRightJoin(rightInput_, rightIndex_);

++rightIndex_;
if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
return nullptr;
}
}
}

if (noMoreRightInput_ && output_) {
output_->resize(outputSize_);
return std::move(output_);
Expand All @@ -833,7 +911,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
for (;;) {
// Catch up input_ with rightInput_.
while (compareResult < 0) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
Expand All @@ -860,7 +939,7 @@ RowVectorPtr MergeJoin::doGetOutput() {

// Catch up rightInput_ with input_.
while (compareResult > 0) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
Expand Down Expand Up @@ -928,7 +1007,12 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

index_ = endIndex;
rightIndex_ = firstNonNull(rightInput_, rightKeys_, endRightIndex);
if (isFullJoin(joinType_)) {
rightIndex_ = endRightIndex;
} else {
rightIndex_ = firstNonNull(rightInput_, rightKeys_, endRightIndex);
}

if (rightIndex_ == rightInput_->size()) {
// Ran out of rows on the right side.
rightInput_ = nullptr;
Expand All @@ -952,6 +1036,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const auto numRows = output->size();

RowVectorPtr fullOuterOutput = nullptr;

BufferPtr indices = allocateIndices(numRows, pool());
auto rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;
Expand All @@ -972,7 +1058,61 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (!isRightJoin(joinType_)) {
if (isFullJoin(joinType_)) {
// For rows that are filtered out, it is necessary to insert additional
// data to ensure the result set is complete. Specifically, each such
// row is replaced by two records: one record containing the
// columns from the left table along with nulls for the
// right table, and another record containing the columns
// from the right table along with nulls for the left table.
// For instance, suppose the current output below is filtered based on
// the condition t > 1.

// 1, 1
// 2, 2
// 3, 3

// In this scenario, the row 1, 1 fails the filter, so we need to insert
// an additional copy of it. Subsequently, we set the left-side columns
// to null in the first copy and the right-side columns to null in the
// second copy. By doing so, we obtain the final result set.

// null, 1
// 1, null
// 2, 2
// 3, 3
fullOuterOutput = BaseVector::create<RowVector>(
output->type(), output->size() + 1, pool());

for (auto i = 0; i < row + 1; i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i, i, 1);
}
}

for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), row + 1, row, 1);
}

for (auto i = row + 1; i < output->size(); i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i + 1, i, 1);
}
}

for (auto& projection : leftProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row, true);
}

for (auto& projection : rightProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
Expand Down Expand Up @@ -1052,10 +1192,17 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (numPassed == numRows) {
// All rows passed.
if (fullOuterOutput) {
return fullOuterOutput;
}
return output;
}

// Some, but not all rows passed.
if (fullOuterOutput) {
return wrap(numPassed, indices, fullOuterOutput);
}

return wrap(numPassed, indices, output);
}

Expand All @@ -1073,6 +1220,12 @@ bool MergeJoin::isFinished() {
// complete.
return noMoreInput_ && noMoreRightInput_ && rightInput_ == nullptr;
}

if (isFullJoin(joinType_)) {
return noMoreInput_ && input_ == nullptr && noMoreRightInput_ &&
rightInput_ == nullptr;
}

return noMoreInput_ && input_ == nullptr;
}

Expand Down
Loading

0 comments on commit 794b8b6

Please sign in to comment.