Skip to content

Commit

Permalink
Refine semi join (#9672)
Browse files Browse the repository at this point in the history
ref #9433

Signed-off-by: xufei <[email protected]>
  • Loading branch information
windtalker authored Nov 27, 2024
1 parent 0c15f1f commit 3f0df40
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 300 deletions.
225 changes: 67 additions & 158 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1705,26 +1705,8 @@ Block Join::joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_in

Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const
{
JoinBuildInfo join_build_info{
enable_fine_grained_shuffle,
fine_grained_shuffle_count,
isEnableSpill(),
hash_join_spill_context->isSpilled(),
build_concurrency,
restore_config.restore_round};

probe_process_info.prepareForHashProbe(
key_names_left,
non_equal_conditions.left_filter_column,
kind,
strictness,
join_build_info.needVirtualDispatchForProbeBlock(),
collators,
restore_config.restore_round);

Block block{};
#define CALL(KIND, STRICTNESS, MAP) \
block = joinBlockSemiImpl<KIND, STRICTNESS, MAP>(join_build_info, probe_process_info);
#define CALL(KIND, STRICTNESS, MAP) block = joinBlockSemiImpl<KIND, STRICTNESS, MAP>(probe_process_info);

using enum ASTTableJoin::Strictness;
using enum ASTTableJoin::Kind;
Expand All @@ -1750,166 +1732,93 @@ Block Join::joinBlockSemi(ProbeProcessInfo & probe_process_info) const

FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_join_prob_failpoint);

return removeUselessColumn(block);
}

namespace
{
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
Block genSemiJoinResult(
ProbeProcessInfo & probe_process_info,
SemiJoinHelper<KIND, STRICTNESS, Maps> * helper,
const NameSet & output_column_names_set)
{
auto ret = helper->genJoinResult(output_column_names_set);
/// (left outer) (anti) semi join never expands the left block, so handling the whole block in one pass is enough
probe_process_info.all_rows_joined_finish = true;

return removeUselessColumn(block);
probe_process_info.semi_join_family_helper.reset();
return ret;
}
} // namespace

template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
Block Join::joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const ProbeProcessInfo & probe_process_info) const
Block Join::joinBlockSemiImpl(ProbeProcessInfo & probe_process_info) const
{
size_t rows = probe_process_info.block.rows();

auto [res, res_list] = JoinPartition::probeBlockSemi<KIND, STRICTNESS, Maps>(
partitions,
rows,
key_sizes,
collators,
join_build_info,
probe_process_info);

RUNTIME_ASSERT(res.size() == rows, "SemiJoinResult size {} must be equal to block size {}", res.size(), rows);
if (is_cancelled())
return {};

const NameSet & probe_output_name_set = has_other_condition
? output_columns_names_set_for_other_condition_after_finalize
: output_column_names_set_after_finalize;
Block block{};
for (size_t i = 0; i < probe_process_info.block.columns(); ++i)
{
const auto & column = probe_process_info.block.getByPosition(i);
if (probe_output_name_set.contains(column.name))
block.insert(column);
}

size_t left_columns = block.columns();
/// Add new columns to the block.
std::vector<size_t> right_column_indices_to_add;

for (size_t i = 0; i < right_sample_block.columns(); ++i)
JoinBuildInfo join_build_info{
enable_fine_grained_shuffle,
fine_grained_shuffle_count,
isEnableSpill(),
hash_join_spill_context->isSpilled(),
build_concurrency,
restore_config.restore_round};
if (probe_process_info.semi_join_family_helper != nullptr)
{
const auto & column = right_sample_block.getByPosition(i);
if (probe_output_name_set.contains(column.name))
{
RUNTIME_CHECK_MSG(
!block.has(column.name),
"block from probe side has a column with the same name: {} as a column in right_sample_block",
column.name);
block.insert(column);
right_column_indices_to_add.push_back(i);
}
// the current block is still not done
RUNTIME_CHECK_MSG(
!probe_process_info.all_rows_joined_finish,
"semi_family_helper should be reset to nullptr after all rows are joined for current block");
}

if constexpr (STRICTNESS == ASTTableJoin::Strictness::All)
else
{
if (!res_list.empty())
{
SemiJoinHelper<KIND, typename Maps::MappedType> helper(
block,
left_columns,
right_column_indices_to_add,
max_block_size,
non_equal_conditions,
is_cancelled);

helper.joinResult(res_list);
// probe a new block
probe_process_info.prepareForHashProbe(
key_names_left,
non_equal_conditions.left_filter_column,
kind,
strictness,
join_build_info.needVirtualDispatchForProbeBlock(),
collators,
restore_config.restore_round);

RUNTIME_CHECK_MSG(res_list.empty(), "SemiJoinResult list must be empty after calculating join result");
}
probe_process_info.semi_join_family_helper = decltype(probe_process_info.semi_join_family_helper)(
new SemiJoinHelper<KIND, STRICTNESS, Maps>(rows, max_block_size, non_equal_conditions, is_cancelled),
[](void * ptr) { delete reinterpret_cast<SemiJoinHelper<KIND, STRICTNESS, Maps> *>(ptr); });
}

if (is_cancelled())
return {};
auto * helper
= reinterpret_cast<SemiJoinHelper<KIND, STRICTNESS, Maps> *>(probe_process_info.semi_join_family_helper.get());

/// Now all results are known.

std::unique_ptr<IColumn::Filter> filter;
if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti)
filter = std::make_unique<IColumn::Filter>(rows);

MutableColumnPtr left_semi_column_ptr = nullptr;
ColumnInt8::Container * left_semi_column_data = nullptr;
ColumnUInt8::Container * left_semi_null_map = nullptr;

if constexpr (KIND == ASTTableJoin::Kind::LeftOuterSemi || KIND == ASTTableJoin::Kind::LeftOuterAnti)
if (!helper->isProbeHashTableDone())
{
left_semi_column_ptr = block.getByPosition(block.columns() - 1).column->cloneEmpty();
auto * left_semi_column = typeid_cast<ColumnNullable *>(left_semi_column_ptr.get());
left_semi_column_data = &typeid_cast<ColumnVector<Int8> &>(left_semi_column->getNestedColumn()).getData();
left_semi_column_data->reserve(rows);
left_semi_null_map = &left_semi_column->getNullMapColumn().getData();
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
{
left_semi_null_map->resize_fill(rows, 0);
}
else
{
left_semi_null_map->reserve(rows);
}
}
const NameSet & probe_output_name_set = has_other_condition
? output_columns_names_set_for_other_condition_after_finalize
: output_column_names_set_after_finalize;

size_t rows_for_semi_anti = 0;
for (size_t i = 0; i < rows; ++i)
{
auto result = res[i].getResult();
if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti)
{
if (isTrueSemiJoinResult(result))
{
// If the result is true, this row should be kept.
(*filter)[i] = 1;
++rows_for_semi_anti;
}
else
{
// If the result is null or false, this row should be filtered.
(*filter)[i] = 0;
}
}
else
{
if constexpr (STRICTNESS == ASTTableJoin::Strictness::Any)
{
left_semi_column_data->push_back(result);
}
else
{
switch (result)
{
case SemiJoinResultType::FALSE_VALUE:
left_semi_column_data->push_back(0);
left_semi_null_map->push_back(0);
break;
case SemiJoinResultType::TRUE_VALUE:
left_semi_column_data->push_back(1);
left_semi_null_map->push_back(0);
break;
case SemiJoinResultType::NULL_VALUE:
left_semi_column_data->push_back(0);
left_semi_null_map->push_back(1);
break;
}
}
}
helper->probeHashTable(
partitions,
key_sizes,
collators,
join_build_info,
probe_process_info,
probe_output_name_set,
right_sample_block);
}

if constexpr (KIND == ASTTableJoin::Kind::LeftOuterSemi || KIND == ASTTableJoin::Kind::LeftOuterAnti)
while (!helper->isJoinDone())
{
block.getByPosition(block.columns() - 1).column = std::move(left_semi_column_ptr);
if (is_cancelled())
return {};
helper->doJoin();
}

if constexpr (KIND == ASTTableJoin::Kind::Semi || KIND == ASTTableJoin::Kind::Anti)
{
for (size_t i = 0; i < left_columns; ++i)
{
auto & column = block.getByPosition(i);
if (output_column_names_set_after_finalize.contains(column.name))
column.column = column.column->filter(*filter, rows_for_semi_anti);
}
}
return block;
if (is_cancelled())
return {};

return genSemiJoinResult<KIND, STRICTNESS, Maps>(
probe_process_info,
helper,
output_column_names_set_after_finalize);
}

void Join::checkTypesOfKeys(const Block & block_left, const Block & block_right) const
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Interpreters/Join.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ class Join
Block joinBlockNullAwareSemiImpl(const ProbeProcessInfo & probe_process_info) const;

template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
Block joinBlockSemiImpl(const JoinBuildInfo & join_build_info, const ProbeProcessInfo & probe_process_info) const;
Block joinBlockSemiImpl(ProbeProcessInfo & probe_process_info) const;

IColumn::Selector hashToSelector(const WeakHash32 & hash) const;
IColumn::Selector selectDispatchBlock(const Strings & key_columns_names, const Block & from_block);
Expand Down
3 changes: 3 additions & 0 deletions dbms/src/Interpreters/ProbeProcessInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ struct ProbeProcessInfo
/// for null-aware join
std::unique_ptr<NullAwareJoinProbeProcessData> null_aware_join_data;

// for semi join family
std::unique_ptr<void, std::function<void(void *)>> semi_join_family_helper = nullptr; /// type erasure

ProbeProcessInfo(UInt64 max_block_size_, UInt64 cache_columns_threshold_)
: partition_index(0)
, max_block_size(max_block_size_)
Expand Down
Loading

0 comments on commit 3f0df40

Please sign in to comment.