From a50413422717a0297d42ac79112c0a1919575df3 Mon Sep 17 00:00:00 2001 From: Masha Basmanova Date: Thu, 7 Dec 2023 22:47:09 -0800 Subject: [PATCH] Extract HashTable::prepareForJoinProbe method (#7926) Summary: Rename HashTable::prepareForProbe to prepareForGroupProbe and add prepareForJoinProbe method. Pull Request resolved: https://github.com/facebookincubator/velox/pull/7926 Reviewed By: xiaoxmeng Differential Revision: D51968271 Pulled By: mbasmanova fbshipit-source-id: f7a29e92331b610ceccd97fb352b4ee3aad20f1e --- velox/exec/GroupingSet.cpp | 2 +- velox/exec/HashProbe.cpp | 28 ++++-------------- velox/exec/HashProbe.h | 2 -- velox/exec/HashTable.cpp | 56 ++++++++++++++++++++++++++++++----- velox/exec/HashTable.h | 57 +++++++++++++++++++++++++++++------- velox/exec/RowNumber.cpp | 6 ++-- velox/exec/TopNRowNumber.cpp | 2 +- 7 files changed, 105 insertions(+), 48 deletions(-) diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp index 5ec3548af7ab..216aa9b1f25b 100644 --- a/velox/exec/GroupingSet.cpp +++ b/velox/exec/GroupingSet.cpp @@ -247,7 +247,7 @@ void GroupingSet::addInputForActiveRows( TestValue::adjust( "facebook::velox::exec::GroupingSet::addInputForActiveRows", this); - table_->prepareForProbe(*lookup_, input, activeRows_, ignoreNullKeys_); + table_->prepareForGroupProbe(*lookup_, input, activeRows_, ignoreNullKeys_); table_->groupProbe(*lookup_); masks_.addInput(input, activeRows_); diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index d06f6027d795..eb2df3fa7d85 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -528,7 +528,9 @@ void HashProbe::addInput(RowVectorPtr input) { } input_ = std::move(input); - if (input_->size() > 0) { + const auto numInput = input_->size(); + + if (numInput > 0) { noInput_ = false; } @@ -577,28 +579,9 @@ void HashProbe::addInput(RowVectorPtr input) { if (!hasDecoded) { decodeAndDetectNonNullKeys(); } - activeRows_ = nonNullInputRows_; - 
lookup_->hashes.resize(input_->size()); - auto mode = table_->hashMode(); - auto& buildHashers = table_->hashers(); - for (auto i = 0; i < keyChannels_.size(); ++i) { - if (mode != BaseHashTable::HashMode::kHash) { - auto key = input_->childAt(keyChannels_[i]); - buildHashers[i]->lookupValueIds( - *key, activeRows_, scratchMemory_, lookup_->hashes); - } else { - hashers_[i]->hash(activeRows_, i > 0, lookup_->hashes); - } - } - lookup_->rows.clear(); - if (activeRows_.isAllSelected()) { - lookup_->rows.resize(activeRows_.size()); - std::iota(lookup_->rows.begin(), lookup_->rows.end(), 0); - } else { - activeRows_.applyToSelected( - [&](auto row) { lookup_->rows.push_back(row); }); - } + + table_->prepareForJoinProbe(*lookup_.get(), input_, activeRows_, false); passingInputRowsInitialized_ = false; if (isLeftJoin(joinType_) || isFullJoin(joinType_) || isAntiJoin(joinType_) || @@ -607,7 +590,6 @@ void HashProbe::addInput(RowVectorPtr input) { // including rows without a match in the output. Also, make sure to // initialize all 'hits' to nullptr as HashTable::joinProbe will only // process activeRows_. - auto numInput = input_->size(); auto& hits = lookup_->hits; hits.resize(numInput); std::fill(hits.data(), hits.data() + numInput, nullptr); diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h index 7beadfb8a1e5..afa021879705 100644 --- a/velox/exec/HashProbe.h +++ b/velox/exec/HashProbe.h @@ -316,8 +316,6 @@ class HashProbe : public Operator { // side. Used by right semi project join. bool probeSideHasNullKeys_{false}; - VectorHasher::ScratchMemory scratchMemory_; - // Rows in 'filterInput_' to apply 'filter_' to. 
SelectivityVector filterInputRows_; diff --git a/velox/exec/HashTable.cpp b/velox/exec/HashTable.cpp index 81168a15d8b4..0fa9218b4414 100644 --- a/velox/exec/HashTable.cpp +++ b/velox/exec/HashTable.cpp @@ -1947,7 +1947,20 @@ void HashTable::checkConsistency() const { template class HashTable; template class HashTable; -void BaseHashTable::prepareForProbe( +namespace { +void populateLookupRows( + const SelectivityVector& rows, + raw_vector& lookupRows) { + if (rows.isAllSelected()) { + std::iota(lookupRows.begin(), lookupRows.end(), 0); + } else { + lookupRows.clear(); + rows.applyToSelected([&](auto row) { lookupRows.push_back(row); }); + } +} +} // namespace + +void BaseHashTable::prepareForGroupProbe( HashLookup& lookup, const RowVectorPtr& input, SelectivityVector& rows, @@ -1984,17 +1997,46 @@ void BaseHashTable::prepareForProbe( decideHashMode(input->size()); // Do not forward 'ignoreNullKeys' to avoid redundant evaluation of // deselectRowsWithNulls. - prepareForProbe(lookup, input, rows, false); + prepareForGroupProbe(lookup, input, rows, false); return; } } - if (rows.isAllSelected()) { - std::iota(lookup.rows.begin(), lookup.rows.end(), 0); - } else { - lookup.rows.clear(); - rows.applyToSelected([&](auto row) { lookup.rows.push_back(row); }); + populateLookupRows(rows, lookup.rows); +} + +void BaseHashTable::prepareForJoinProbe( + HashLookup& lookup, + const RowVectorPtr& input, + SelectivityVector& rows, + bool decodeAndRemoveNulls) { + auto& hashers = lookup.hashers; + + if (decodeAndRemoveNulls) { + for (auto& hasher : hashers) { + auto key = input->childAt(hasher->channel())->loadedVector(); + hasher->decode(*key, rows); + } + + // A null in any of the keys disables the row. 
+ deselectRowsWithNulls(hashers, rows); + } + + lookup.reset(rows.end()); + + const auto mode = hashMode(); + for (auto i = 0; i < hashers.size(); ++i) { + auto& hasher = hashers[i]; + if (mode != BaseHashTable::HashMode::kHash) { + auto& key = input->childAt(hasher->channel()); + hashers_[i]->lookupValueIds( + *key, rows, lookup.scratchMemory, lookup.hashes); + } else { + hasher->hash(rows, i > 0, lookup.hashes); + } } + + populateLookupRows(rows, lookup.rows); } } // namespace facebook::velox::exec diff --git a/velox/exec/HashTable.h b/velox/exec/HashTable.h index bf8c43deb513..86dd236d1041 100644 --- a/velox/exec/HashTable.h +++ b/velox/exec/HashTable.h @@ -25,6 +25,7 @@ namespace facebook::velox::exec { using PartitionBoundIndexType = int64_t; +/// Contains input and output parameters for groupProbe and joinProbe APIs. struct HashLookup { explicit HashLookup(const std::vector>& h) : hashers(h) {} @@ -36,17 +37,36 @@ struct HashLookup { newGroups.clear(); } - // One entry per aggregation or join key + /// One entry per group-by or join key. const std::vector>& hashers; + + /// Scratch memory used to call VectorHasher::lookupValueIds. + VectorHasher::ScratchMemory scratchMemory; + + /// Input to groupProbe and joinProbe APIs. + + /// Set of row numbers of rows to probe. raw_vector rows; - // Hash number for all input rows. + + /// Hashes or value IDs for rows in 'rows'. Not aligned with 'rows'. Index is + /// the row number. raw_vector hashes; - // If using valueIds, list of concatenated valueIds. 1:1 with 'hashes'. - raw_vector normalizedKeys; - // Hit for each row of input corresponding group row or join row. + + /// Results of groupProbe and joinProbe APIs. + + /// Contains one entry for each row in 'rows'. Index is the row number. + /// For groupProbe, a pointer to an existing or new row with matching grouping + /// keys. For joinProbe, a pointer to the first row with matching keys or null + /// if no match. 
raw_vector hits; - // Indices of newly inserted rows (not found during probe). + + /// For groupProbe, row numbers for which a new entry was inserted (didn't + /// exist before the groupProbe). Empty for joinProbe. std::vector newGroups; + + /// If using valueIds, list of concatenated valueIds. 1:1 with 'hashes'. + /// Populated by groupProbe and joinProbe. + raw_vector normalizedKeys; }; struct HashTableStats { @@ -124,7 +144,12 @@ class BaseHashTable { virtual HashStringAllocator* stringAllocator() = 0; - void prepareForProbe( + /// Populates 'hashes' and 'rows' fields in 'lookup' in preparation for + /// 'groupProbe' call. Rehashes the table if necessary. Uses lookup.hashers to + /// decode grouping keys from 'input'. If 'ignoreNullKeys' is true, updates + /// 'rows' to remove entries with null grouping keys. After this call, 'rows' + /// may have no entries selected. + void prepareForGroupProbe( HashLookup& lookup, const RowVectorPtr& input, SelectivityVector& rows, @@ -139,6 +164,20 @@ class BaseHashTable { /// join probe. Use listJoinResults to iterate over the results. virtual void joinProbe(HashLookup& lookup) = 0; + /// Populates 'hashes' and 'rows' fields in 'lookup' in preparation for + /// 'joinProbe' call. If hash mode is not kHash, populates 'hashes' with + /// value IDs. Rows which do not have value IDs are removed from 'rows' + /// (these rows cannot possibly match). If 'decodeAndRemoveNulls' is true, + /// uses lookup.hashers to decode join keys from 'input' and updates 'rows' + /// to remove entries with null join keys. Otherwise, assumes the caller + /// has done that already. After this call, 'rows' may have no entries + /// selected. + void prepareForJoinProbe( + HashLookup& lookup, + const RowVectorPtr& input, + SelectivityVector& rows, + bool decodeAndRemoveNulls); + /// Fills 'hits' with consecutive hash join results. The corresponding element /// of 'inputRows' is set to the corresponding row number in probe keys. 
/// Returns the number of hits produced. If this s less than hits.size() then @@ -256,10 +295,6 @@ class BaseHashTable { return rows_.get(); } - std::unique_ptr moveRows() { - return std::move(rows_); - } - // Static functions for processing internals. Public because used in // structs that define probe and insert algorithms. diff --git a/velox/exec/RowNumber.cpp b/velox/exec/RowNumber.cpp index 1d2b8cd60e43..3a103fcf32e6 100644 --- a/velox/exec/RowNumber.cpp +++ b/velox/exec/RowNumber.cpp @@ -78,7 +78,7 @@ void RowNumber::addInput(RowVectorPtr input) { } SelectivityVector rows(numInput); - table_->prepareForProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe(*lookup_, input, rows, false); table_->groupProbe(*lookup_); // Initialize new partitions with zeros. @@ -93,7 +93,7 @@ void RowNumber::addInput(RowVectorPtr input) { void RowNumber::addSpillInput() { const auto numInput = input_->size(); SelectivityVector rows(numInput); - table_->prepareForProbe(*lookup_, input_, rows, false); + table_->prepareForGroupProbe(*lookup_, input_, rows, false); table_->groupProbe(*lookup_); // Initialize new partitions with zeros. @@ -157,7 +157,7 @@ void RowNumber::restoreNextSpillPartition() { const auto numInput = input->size(); SelectivityVector rows(numInput); - table_->prepareForProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe(*lookup_, input, rows, false); table_->groupProbe(*lookup_); auto* counts = data->children().back()->as>(); diff --git a/velox/exec/TopNRowNumber.cpp b/velox/exec/TopNRowNumber.cpp index 5bb0791fbd2e..6c2171d3e40b 100644 --- a/velox/exec/TopNRowNumber.cpp +++ b/velox/exec/TopNRowNumber.cpp @@ -191,7 +191,7 @@ void TopNRowNumber::addInput(RowVectorPtr input) { ensureInputFits(input); SelectivityVector rows(numInput); - table_->prepareForProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe(*lookup_, input, rows, false); table_->groupProbe(*lookup_); // Initialize new partitions.