From 999018c58552716533223b9452de965904bd61f6 Mon Sep 17 00:00:00 2001 From: guo-shaoge Date: Fri, 17 Jan 2025 16:31:58 +0800 Subject: [PATCH] Aggregator support prefetch (#9679) close pingcap/tiflash#9680 Signed-off-by: guo-shaoge Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- .../AggregateFunctionGroupUniqArray.h | 6 +- .../src/AggregateFunctions/KeyHolderHelpers.h | 4 +- dbms/src/Common/ColumnsHashing.h | 101 +++-- dbms/src/Common/ColumnsHashingImpl.h | 95 ++++- dbms/src/Common/FailPoint.cpp | 1 + dbms/src/Common/HashTable/FixedHashTable.h | 5 +- dbms/src/Common/HashTable/Hash.h | 38 +- dbms/src/Common/HashTable/HashTable.h | 8 + .../src/Common/HashTable/HashTableKeyHolder.h | 12 +- dbms/src/Common/HashTable/SmallTable.h | 4 + dbms/src/Common/HashTable/StringHashMap.h | 32 +- dbms/src/Common/HashTable/StringHashTable.h | 40 +- dbms/src/Common/HashTable/TwoLevelHashTable.h | 16 + .../HashTable/TwoLevelStringHashTable.h | 25 +- .../tests/gtest_aggregation_executor.cpp | 136 ++++--- dbms/src/Flash/tests/gtest_compute_server.cpp | 3 + .../Flash/tests/gtest_spill_aggregation.cpp | 45 ++- dbms/src/Interpreters/Aggregator.cpp | 357 +++++++++++++----- dbms/src/Interpreters/Aggregator.h | 21 +- dbms/src/Interpreters/JoinPartition.cpp | 28 +- dbms/src/Interpreters/SetVariants.h | 4 +- libs/libcommon/include/common/StringRef.h | 2 +- 22 files changed, 655 insertions(+), 328 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h index 06dd57edf66..d3cbea74195 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupUniqArray.h @@ -182,18 +182,18 @@ class AggregateFunctionGroupUniqArrayGeneric { // We have to copy the keys to our arena. assert(arena != nullptr); - cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), *arena}, it, inserted); + cur_set.emplace(ArenaKeyHolder{rhs_elem.getValue(), arena}, it, inserted); } } void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - ColumnArray & arr_to = assert_cast(to); + auto & arr_to = assert_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & data_to = arr_to.getData(); auto & set = this->data(place).value; - offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + set.size()); + offsets_to.push_back((offsets_to.empty() ? 0 : offsets_to.back()) + set.size()); for (auto & elem : set) deserializeAndInsert(elem.getValue(), data_to); diff --git a/dbms/src/AggregateFunctions/KeyHolderHelpers.h b/dbms/src/AggregateFunctions/KeyHolderHelpers.h index 6677866f0d3..5c3a617a1cb 100644 --- a/dbms/src/AggregateFunctions/KeyHolderHelpers.h +++ b/dbms/src/AggregateFunctions/KeyHolderHelpers.h @@ -24,14 +24,14 @@ inline auto getKeyHolder(const IColumn & column, size_t row_num, Arena & arena) { if constexpr (is_plain_column) { - return ArenaKeyHolder{column.getDataAt(row_num), arena}; + return ArenaKeyHolder{column.getDataAt(row_num), &arena}; } else { const char * begin = nullptr; StringRef serialized = column.serializeValueIntoArena(row_num, arena, begin); assert(serialized.data != nullptr); - return SerializedKeyHolder{serialized, arena}; + return SerializedKeyHolder{serialized, &arena}; } } diff --git a/dbms/src/Common/ColumnsHashing.h b/dbms/src/Common/ColumnsHashing.h index 398d6605e60..a03136bbed8 100644 --- a/dbms/src/Common/ColumnsHashing.h +++ b/dbms/src/Common/ColumnsHashing.h @@ -47,6 +47,9 @@ struct HashMethodOneNumber { using Self = HashMethodOneNumber; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = FieldType; + + static constexpr bool is_serialized_key = false; const FieldType * vec; @@ -73,7 +76,7 @@ struct HashMethodOneNumber using Base::getHash; /// (const Data & data, size_t row, Arena & pool) -> size_t /// Is used for default implementation in HashMethodBase. - ALWAYS_INLINE inline FieldType getKeyHolder(size_t row, Arena *, std::vector &) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena *, std::vector &) const { if constexpr (std::is_same_v) return vec[row]; @@ -86,13 +89,15 @@ struct HashMethodOneNumber /// For the case when there is one string key. -template +template struct HashMethodString - : public columns_hashing_impl:: - HashMethodBase, Value, Mapped, use_cache> + : public columns_hashing_impl::HashMethodBase, Value, Mapped, use_cache> { - using Self = HashMethodString; + using Self = HashMethodString; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = ArenaKeyHolder; + + static constexpr bool is_serialized_key = false; const IColumn::Offset * offsets; const UInt8 * chars; @@ -108,14 +113,10 @@ struct HashMethodString offsets = column_string.getOffsets().data(); chars = column_string.getChars().data(); if (!collators.empty()) - { - if constexpr (!place_string_to_arena) - throw Exception("String with collator must be placed on arena.", ErrorCodes::LOGICAL_ERROR); collator = collators[0]; - } } - ALWAYS_INLINE inline auto getKeyHolder( + ALWAYS_INLINE inline KeyHolderType getKeyHolder( ssize_t row, [[maybe_unused]] Arena * pool, std::vector & sort_key_containers) const @@ -123,17 +124,10 @@ struct HashMethodString auto last_offset = row == 0 ? 0 : offsets[row - 1]; // Remove last zero byte. StringRef key(chars + last_offset, offsets[row] - last_offset - 1); + if (likely(collator)) + key = collator->sortKey(key.data, key.size, sort_key_containers[0]); - if constexpr (place_string_to_arena) - { - if (likely(collator)) - key = collator->sortKey(key.data, key.size, sort_key_containers[0]); - return ArenaKeyHolder{key, *pool}; - } - else - { - return key; - } + return ArenaKeyHolder{key, pool}; } protected: @@ -146,6 +140,9 @@ struct HashMethodStringBin { using Self = HashMethodStringBin; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = ArenaKeyHolder; + + static constexpr bool is_serialized_key = false; const IColumn::Offset * offsets; const UInt8 * chars; @@ -158,12 +155,12 @@ struct HashMethodStringBin chars = column_string.getChars().data(); } - ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector &) const { auto last_offset = row == 0 ? 0 : offsets[row - 1]; StringRef key(chars + last_offset, offsets[row] - last_offset - 1); key = BinCollatorSortKey(key.data, key.size); - return ArenaKeyHolder{key, *pool}; + return ArenaKeyHolder{key, pool}; } protected: @@ -343,6 +340,9 @@ struct HashMethodFastPathTwoKeysSerialized { using Self = HashMethodFastPathTwoKeysSerialized; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = SerializedKeyHolder; + + static constexpr bool is_serialized_key = true; Key1Desc key_1_desc; Key2Desc key_2_desc; @@ -352,13 +352,13 @@ struct HashMethodFastPathTwoKeysSerialized , key_2_desc(key_columns[1]) {} - ALWAYS_INLINE inline auto getKeyHolder(ssize_t row, Arena * pool, std::vector &) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector &) const { StringRef key1; StringRef key2; size_t alloc_size = key_1_desc.getKey(row, key1) + key_2_desc.getKey(row, key2); char * start = pool->alloc(alloc_size); - SerializedKeyHolder ret{{start, alloc_size}, *pool}; + SerializedKeyHolder ret{{start, alloc_size}, pool}; Key1Desc::serializeKey(start, key1); Key2Desc::serializeKey(start, key2); return ret; @@ -370,16 +370,16 @@ struct HashMethodFastPathTwoKeysSerialized /// For the case when there is one fixed-length string key. -template +template struct HashMethodFixedString - : public columns_hashing_impl::HashMethodBase< - HashMethodFixedString, - Value, - Mapped, - use_cache> + : public columns_hashing_impl:: + HashMethodBase, Value, Mapped, use_cache> { - using Self = HashMethodFixedString; + using Self = HashMethodFixedString; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = ArenaKeyHolder; + + static constexpr bool is_serialized_key = false; size_t n; const ColumnFixedString::Chars_t * chars; @@ -398,26 +398,14 @@ struct HashMethodFixedString collator = collators[0]; } - ALWAYS_INLINE inline auto getKeyHolder( - size_t row, - [[maybe_unused]] Arena * pool, - std::vector & sort_key_containers) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector & sort_key_containers) + const { StringRef key(&(*chars)[row * n], n); - if (collator) - { key = collator->sortKeyFastPath(key.data, key.size, sort_key_containers[0]); - } - if constexpr (place_string_to_arena) - { - return ArenaKeyHolder{key, *pool}; - } - else - { - return key; - } + return ArenaKeyHolder{key, pool}; } protected: @@ -437,7 +425,9 @@ struct HashMethodKeysFixed using Self = HashMethodKeysFixed; using BaseHashed = columns_hashing_impl::HashMethodBase; using Base = columns_hashing_impl::BaseStateKeysFixed; + using KeyHolderType = Key; + static constexpr bool is_serialized_key = false; static constexpr bool has_nullable_keys = has_nullable_keys_; Sizes key_sizes; @@ -526,7 +516,7 @@ struct HashMethodKeysFixed #endif } - ALWAYS_INLINE inline Key getKeyHolder(size_t row, Arena *, std::vector &) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena *, std::vector &) const { if constexpr (has_nullable_keys) { @@ -592,6 +582,9 @@ struct HashMethodSerialized { using Self = HashMethodSerialized; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = SerializedKeyHolder; + + static constexpr bool is_serialized_key = true; ColumnRawPtrs key_columns; size_t keys_size; @@ -606,14 +599,12 @@ struct HashMethodSerialized , collators(collators_) {} - ALWAYS_INLINE inline SerializedKeyHolder getKeyHolder( - size_t row, - Arena * pool, - std::vector & sort_key_containers) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector & sort_key_containers) + const { return SerializedKeyHolder{ serializeKeysToPoolContiguous(row, keys_size, key_columns, collators, sort_key_containers, *pool), - *pool}; + pool}; } protected: @@ -628,6 +619,9 @@ struct HashMethodHashed using Key = UInt128; using Self = HashMethodHashed; using Base = columns_hashing_impl::HashMethodBase; + using KeyHolderType = Key; + + static constexpr bool is_serialized_key = false; ColumnRawPtrs key_columns; TiDB::TiDBCollators collators; @@ -637,7 +631,8 @@ struct HashMethodHashed , collators(collators_) {} - ALWAYS_INLINE inline Key getKeyHolder(size_t row, Arena *, std::vector & sort_key_containers) const + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena *, std::vector & sort_key_containers) + const { return hash128(row, key_columns.size(), key_columns, collators, sort_key_containers); } diff --git a/dbms/src/Common/ColumnsHashingImpl.h b/dbms/src/Common/ColumnsHashingImpl.h index d4f4143015d..bf130d2bd29 100644 --- a/dbms/src/Common/ColumnsHashingImpl.h +++ b/dbms/src/Common/ColumnsHashingImpl.h @@ -59,24 +59,24 @@ struct LastElementCache template class EmplaceResultImpl { - Mapped & value; - Mapped & cached_value; - bool inserted; + Mapped * value = nullptr; + Mapped * cached_value = nullptr; + bool inserted = false; public: EmplaceResultImpl(Mapped & value_, Mapped & cached_value_, bool inserted_) - : value(value_) - , cached_value(cached_value_) + : value(&value_) + , cached_value(&cached_value_) , inserted(inserted_) {} bool isInserted() const { return inserted; } - auto & getMapped() const { return value; } + auto & getMapped() const { return *value; } void setMapped(const Mapped & mapped) { - cached_value = mapped; - value = mapped; + *cached_value = mapped; + *value = mapped; } }; @@ -119,7 +119,7 @@ class FindResultImpl bool isFound() const { return found; } }; -template +template class HashMethodBase { public: @@ -127,6 +127,7 @@ class HashMethodBase using FindResult = FindResultImpl; static constexpr bool has_mapped = !std::is_same::value; using Cache = LastElementCache; + using Derived = TDerived; template ALWAYS_INLINE inline EmplaceResult emplaceKey( @@ -139,6 +140,12 @@ class HashMethodBase return emplaceImpl(key_holder, data); } + template + ALWAYS_INLINE inline EmplaceResult emplaceKey(Data & data, KeyHolder && key_holder, size_t hashval) + { + return emplaceImpl(key_holder, data, hashval); + } + template ALWAYS_INLINE inline FindResult findKey( Data & data, @@ -150,14 +157,20 @@ class HashMethodBase return findKeyImpl(keyHolderGetKey(key_holder), data); } + template + ALWAYS_INLINE inline FindResult findKey(Data & data, KeyHolder && key_holder, size_t hashval) + { + return findKeyImpl(keyHolderGetKey(key_holder), data, hashval); + } + template ALWAYS_INLINE inline size_t getHash( const Data & data, size_t row, Arena & pool, - std::vector & sort_key_containers) + std::vector & sort_key_containers) const { - auto key_holder = static_cast(*this).getKeyHolder(row, &pool, sort_key_containers); + auto key_holder = static_cast(*this).getKeyHolder(row, &pool, sort_key_containers); return data.hash(keyHolderGetKey(key_holder)); } @@ -179,6 +192,28 @@ class HashMethodBase } } + template + ALWAYS_INLINE inline EmplaceResult emplaceImpl(KeyHolder & key_holder, Data & data, size_t hashval) + { + if constexpr (Cache::consecutive_keys_optimization) + { + if (cache.found && cache.check(keyHolderGetKey(key_holder))) + { + if constexpr (has_mapped) + return EmplaceResult(cache.value.second, cache.value.second, false); + else + return EmplaceResult(false); + } + } + + typename Data::LookupResult it; + bool inserted = false; + + data.emplace(key_holder, it, inserted, hashval); + + return handleEmplaceResult(it, inserted); + } + template ALWAYS_INLINE inline EmplaceResult emplaceImpl(KeyHolder & key_holder, Data & data) { @@ -195,8 +230,15 @@ class HashMethodBase typename Data::LookupResult it; bool inserted = false; + data.emplace(key_holder, it, inserted); + return handleEmplaceResult(it, inserted); + } + + template + ALWAYS_INLINE inline EmplaceResult handleEmplaceResult(typename Data::LookupResult & it, bool inserted) + { [[maybe_unused]] Mapped * cached = nullptr; if constexpr (has_mapped) cached = &it->getMapped(); @@ -233,7 +275,27 @@ class HashMethodBase } template - ALWAYS_INLINE inline FindResult findKeyImpl(Key key, Data & data) + ALWAYS_INLINE inline FindResult findKeyImpl(Key & key, Data & data) + { + if constexpr (Cache::consecutive_keys_optimization) + { + if (cache.check(key)) + { + if constexpr (has_mapped) + return FindResult(&cache.value.second, cache.found); + else + return FindResult(cache.found); + } + } + + typename Data::LookupResult it; + it = data.find(key); + + return handleFindResult(key, it); + } + + template + ALWAYS_INLINE inline FindResult findKeyImpl(Key & key, Data & data, size_t hashval) { if constexpr (Cache::consecutive_keys_optimization) { @@ -246,8 +308,15 @@ class HashMethodBase } } - auto it = data.find(key); + typename Data::LookupResult it; + it = data.find(key, hashval); + return handleFindResult(key, it); + } + + template + ALWAYS_INLINE inline FindResult handleFindResult(Key & key, typename Data::LookupResult & it) + { if constexpr (consecutive_keys_optimization) { cache.found = it != nullptr; diff --git a/dbms/src/Common/FailPoint.cpp b/dbms/src/Common/FailPoint.cpp index be684cf3751..49f7f97f5fc 100644 --- a/dbms/src/Common/FailPoint.cpp +++ b/dbms/src/Common/FailPoint.cpp @@ -114,6 +114,7 @@ namespace DB M(force_set_parallel_prehandle_threshold) \ M(force_raise_prehandle_exception) \ M(force_agg_on_partial_block) \ + M(force_agg_prefetch) \ M(force_set_fap_candidate_store_id) \ M(force_not_clean_fap_on_destroy) \ M(force_fap_worker_throw) \ diff --git a/dbms/src/Common/HashTable/FixedHashTable.h b/dbms/src/Common/HashTable/FixedHashTable.h index 259e90684fc..8b0b721aa8c 100644 --- a/dbms/src/Common/HashTable/FixedHashTable.h +++ b/dbms/src/Common/HashTable/FixedHashTable.h @@ -212,7 +212,6 @@ class FixedHashTable typename cell_type::CellExt cell; }; - public: using key_type = Key; using mapped_type = typename Cell::mapped_type; @@ -222,6 +221,8 @@ class FixedHashTable using LookupResult = Cell *; using ConstLookupResult = const Cell *; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; size_t hash(const Key & x) const { return x; } @@ -352,6 +353,8 @@ class FixedHashTable iterator end() { return iterator(this, buf ? buf + NUM_CELLS : buf); } + inline void prefetch(size_t) {} + /// The last parameter is unused but exists for compatibility with HashTable interface. void ALWAYS_INLINE emplace(const Key & x, LookupResult & it, bool & inserted, size_t /* hash */ = 0) { diff --git a/dbms/src/Common/HashTable/Hash.h b/dbms/src/Common/HashTable/Hash.h index b4f5d2c0a04..457b4b9f3c0 100644 --- a/dbms/src/Common/HashTable/Hash.h +++ b/dbms/src/Common/HashTable/Hash.h @@ -130,8 +130,8 @@ inline DB::UInt64 wideIntHashCRC32(const T & x, DB::UInt64 updated_value) return updated_value; } static_assert( - DB::IsDecimal< - T> || is_boost_number_v || std::is_same_v || std::is_same_v || std::is_same_v); + DB::IsDecimal || is_boost_number_v || std::is_same_v || std::is_same_v + || std::is_same_v); __builtin_unreachable(); } @@ -244,8 +244,8 @@ inline size_t defaultHash64(const std::enable_if_t, T> & key return boost::multiprecision::hash_value(key); } static_assert( - is_boost_number_v< - T> || std::is_same_v || std::is_same_v || std::is_same_v); + is_boost_number_v || std::is_same_v || std::is_same_v + || std::is_same_v); __builtin_unreachable(); } @@ -297,20 +297,26 @@ inline size_t hashCRC32(const std::enable_if_t, T> & key) template struct HashCRC32; -#define DEFINE_HASH(T) \ - template <> \ - struct HashCRC32 \ - { \ - static_assert(is_fit_register); \ - size_t operator()(T key) const { return hashCRC32(key); } \ +#define DEFINE_HASH(T) \ + template <> \ + struct HashCRC32 \ + { \ + static_assert(is_fit_register); \ + size_t operator()(T key) const \ + { \ + return hashCRC32(key); \ + } \ }; -#define DEFINE_HASH_WIDE(T) \ - template <> \ - struct HashCRC32 \ - { \ - static_assert(!is_fit_register); \ - size_t operator()(const T & key) const { return hashCRC32(key); } \ +#define DEFINE_HASH_WIDE(T) \ + template <> \ + struct HashCRC32 \ + { \ + static_assert(!is_fit_register); \ + size_t operator()(const T & key) const \ + { \ + return hashCRC32(key); \ + } \ }; DEFINE_HASH(DB::UInt8) diff --git a/dbms/src/Common/HashTable/HashTable.h b/dbms/src/Common/HashTable/HashTable.h index a4f0fe3be03..c0f066edbb0 100644 --- a/dbms/src/Common/HashTable/HashTable.h +++ b/dbms/src/Common/HashTable/HashTable.h @@ -402,6 +402,9 @@ class HashTable using Grower = GrowerType; using Allocator = AllocatorType; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; + protected: friend class const_iterator; friend class iterator; @@ -851,6 +854,11 @@ class HashTable iterator end() { return iterator(this, buf ? buf + grower.bufSize() : buf); } + void ALWAYS_INLINE prefetch(size_t hashval) const + { + const size_t place_value = grower.place(hashval); + __builtin_prefetch(static_cast(&buf[place_value])); + } protected: const_iterator iteratorTo(const Cell * ptr) const { return const_iterator(this, ptr); } diff --git a/dbms/src/Common/HashTable/HashTableKeyHolder.h b/dbms/src/Common/HashTable/HashTableKeyHolder.h index 01b06dce87d..2c81b050789 100644 --- a/dbms/src/Common/HashTable/HashTableKeyHolder.h +++ b/dbms/src/Common/HashTable/HashTableKeyHolder.h @@ -92,7 +92,7 @@ namespace DB struct ArenaKeyHolder { StringRef key; - Arena & pool; + Arena * pool; }; } // namespace DB @@ -111,14 +111,14 @@ inline void ALWAYS_INLINE keyHolderPersistKey(DB::ArenaKeyHolder & holder) { // Hash table shouldn't ask us to persist a zero key assert(holder.key.size > 0); - holder.key.data = holder.pool.insert(holder.key.data, holder.key.size); + holder.key.data = holder.pool->insert(holder.key.data, holder.key.size); } inline void ALWAYS_INLINE keyHolderPersistKey(DB::ArenaKeyHolder && holder) { // Hash table shouldn't ask us to persist a zero key assert(holder.key.size > 0); - holder.key.data = holder.pool.insert(holder.key.data, holder.key.size); + holder.key.data = holder.pool->insert(holder.key.data, holder.key.size); } inline void ALWAYS_INLINE keyHolderDiscardKey(DB::ArenaKeyHolder &) {} @@ -134,7 +134,7 @@ namespace DB struct SerializedKeyHolder { StringRef key; - Arena & pool; + Arena * pool; }; } // namespace DB @@ -157,7 +157,7 @@ inline void ALWAYS_INLINE keyHolderDiscardKey(DB::SerializedKeyHolder & holder) { //[[maybe_unused]] void * new_head = holder.pool.rollback(holder.key.size); //assert(new_head == holder.key.data); - holder.pool.rollback(holder.key.size); + holder.pool->rollback(holder.key.size); holder.key.data = nullptr; holder.key.size = 0; } @@ -166,7 +166,7 @@ inline void ALWAYS_INLINE keyHolderDiscardKey(DB::SerializedKeyHolder && holder) { //[[maybe_unused]] void * new_head = holder.pool.rollback(holder.key.size); //assert(new_head == holder.key.data); - holder.pool.rollback(holder.key.size); + holder.pool->rollback(holder.key.size); holder.key.data = nullptr; holder.key.size = 0; } diff --git a/dbms/src/Common/HashTable/SmallTable.h b/dbms/src/Common/HashTable/SmallTable.h index fa40b479430..1292a4205da 100644 --- a/dbms/src/Common/HashTable/SmallTable.h +++ b/dbms/src/Common/HashTable/SmallTable.h @@ -85,6 +85,9 @@ class SmallTable using value_type = typename Cell::value_type; using cell_type = Cell; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = false; + class Reader final : private Cell::State { public: @@ -296,6 +299,7 @@ class SmallTable iterator ALWAYS_INLINE find(Key x) { return iteratorTo(findCell(x)); } const_iterator ALWAYS_INLINE find(Key x) const { return iteratorTo(findCell(x)); } + void ALWAYS_INLINE prefetch(size_t) {} void write(DB::WriteBuffer & wb) const { diff --git a/dbms/src/Common/HashTable/StringHashMap.h b/dbms/src/Common/HashTable/StringHashMap.h index 6f7e668e1d9..a070f0ef0a9 100644 --- a/dbms/src/Common/HashTable/StringHashMap.h +++ b/dbms/src/Common/HashTable/StringHashMap.h @@ -90,31 +90,15 @@ struct StringHashMapCell template struct StringHashMapSubMaps { + using Hash = StringHashTableHash; using T0 = StringHashTableEmpty>; - using T1 = HashMapTable< - StringKey8, - StringHashMapCell, - StringHashTableHash, - StringHashTableGrower<>, - Allocator>; - using T2 = HashMapTable< - StringKey16, - StringHashMapCell, - StringHashTableHash, - StringHashTableGrower<>, - Allocator>; - using T3 = HashMapTable< - StringKey24, - StringHashMapCell, - StringHashTableHash, - StringHashTableGrower<>, - Allocator>; - using Ts = HashMapTable< - StringRef, - StringHashMapCell, - StringHashTableHash, - StringHashTableGrower<>, - Allocator>; + using T1 + = HashMapTable, Hash, StringHashTableGrower<>, Allocator>; + using T2 + = HashMapTable, Hash, StringHashTableGrower<>, Allocator>; + using T3 + = HashMapTable, Hash, StringHashTableGrower<>, Allocator>; + using Ts = HashMapTable, Hash, StringHashTableGrower<>, Allocator>; }; template diff --git a/dbms/src/Common/HashTable/StringHashTable.h b/dbms/src/Common/HashTable/StringHashTable.h index aa4825f171a..e1236caf381 100644 --- a/dbms/src/Common/HashTable/StringHashTable.h +++ b/dbms/src/Common/HashTable/StringHashTable.h @@ -16,11 +16,11 @@ #include #include +#include #include #include - using StringKey8 = UInt64; using StringKey16 = DB::UInt128; struct StringKey24 @@ -51,20 +51,20 @@ inline StringRef ALWAYS_INLINE toStringRef(const StringKey24 & n) struct StringHashTableHash { #if defined(__SSE4_2__) - size_t ALWAYS_INLINE operator()(StringKey8 key) const + static size_t ALWAYS_INLINE operator()(StringKey8 key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key); return res; } - size_t ALWAYS_INLINE operator()(const StringKey16 & key) const + static size_t ALWAYS_INLINE operator()(const StringKey16 & key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key.low); res = _mm_crc32_u64(res, key.high); return res; } - size_t ALWAYS_INLINE operator()(const StringKey24 & key) const + static size_t ALWAYS_INLINE operator()(const StringKey24 & key) { size_t res = -1ULL; res = _mm_crc32_u64(res, key.a); @@ -73,20 +73,20 @@ struct StringHashTableHash return res; } #else - size_t ALWAYS_INLINE operator()(StringKey8 key) const + static size_t ALWAYS_INLINE operator()(StringKey8 key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 8); } - size_t ALWAYS_INLINE operator()(const StringKey16 & key) const + static size_t ALWAYS_INLINE operator()(const StringKey16 & key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 16); } - size_t ALWAYS_INLINE operator()(const StringKey24 & key) const + static size_t ALWAYS_INLINE operator()(const StringKey24 & key) { return CityHash_v1_0_2::CityHash64(reinterpret_cast(&key), 24); } #endif - size_t ALWAYS_INLINE operator()(StringRef key) const { return StringRefHash()(key); } + static size_t ALWAYS_INLINE operator()(StringRef key) { return StringRefHash()(key); } }; template @@ -150,6 +150,7 @@ struct StringHashTableEmpty //-V730 return hasZero() ? zeroValue() : nullptr; } + ALWAYS_INLINE inline void prefetch(size_t) {} void write(DB::WriteBuffer & wb) const { zeroValue()->write(wb); } void writeText(DB::WriteBuffer & wb) const { zeroValue()->writeText(wb); } void read(DB::ReadBuffer & rb) { zeroValue()->read(rb); } @@ -157,6 +158,7 @@ struct StringHashTableEmpty //-V730 size_t size() const { return hasZero() ? 1 : 0; } bool empty() const { return !hasZero(); } size_t getBufferSizeInBytes() const { return sizeof(Cell); } + size_t getBufferSizeInCells() const { return 1; } void setResizeCallback(const ResizeCallback &) {} size_t getCollisions() const { return 0; } }; @@ -195,6 +197,8 @@ class StringHashTable : private boost::noncopyable { protected: static constexpr size_t NUM_MAPS = 5; + using Self = StringHashTable; + // Map for storing empty string using T0 = typename SubMaps::T0; @@ -205,7 +209,6 @@ class StringHashTable : private boost::noncopyable // Long strings are stored as StringRef along with saved hash using Ts = typename SubMaps::Ts; - using Self = StringHashTable; template friend class TwoLevelStringHashTable; @@ -226,6 +229,9 @@ class StringHashTable : private boost::noncopyable using LookupResult = StringHashTableLookupResult; using ConstLookupResult = StringHashTableLookupResult; + static constexpr bool is_string_hash_map = true; + static constexpr bool is_two_level = false; + StringHashTable() = default; explicit StringHashTable(size_t reserve_for_num_elements) @@ -257,7 +263,6 @@ class StringHashTable : private boost::noncopyable #endif dispatch(Self & self, KeyHolder && key_holder, Func && func) { - StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); const size_t sz = x.size; if (sz == 0) @@ -270,7 +275,7 @@ class StringHashTable : private boost::noncopyable { // Strings with trailing zeros are not representable as fixed-size // string keys. Put them to the generic table. - return func(self.ms, std::forward(key_holder), hash(x)); + return func(self.ms, std::forward(key_holder), StringHashTableHash::operator()(x)); } const char * p = x.data; @@ -306,7 +311,7 @@ class StringHashTable : private boost::noncopyable n[0] <<= s; } keyHolderDiscardKey(key_holder); - return func(self.m1, k8, hash(k8)); + return func(self.m1, k8, StringHashTableHash::operator()(k8)); } case 1: // 9..16 bytes { @@ -318,7 +323,7 @@ class StringHashTable : private boost::noncopyable else n[1] <<= s; keyHolderDiscardKey(key_holder); - return func(self.m2, k16, hash(k16)); + return func(self.m2, k16, StringHashTableHash::operator()(k16)); } case 2: // 17..24 bytes { @@ -330,11 +335,11 @@ class StringHashTable : private boost::noncopyable else n[2] <<= s; keyHolderDiscardKey(key_holder); - return func(self.m3, k24, hash(k24)); + return func(self.m3, k24, StringHashTableHash::operator()(k24)); } default: // >= 25 bytes { - return func(self.ms, std::forward(key_holder), hash(x)); + return func(self.ms, std::forward(key_holder), StringHashTableHash::operator()(x)); } } } @@ -434,6 +439,11 @@ class StringHashTable : private boost::noncopyable bool empty() const { return m0.empty() && m1.empty() && m2.empty() && m3.empty() && ms.empty(); } + size_t getBufferSizeInCells() const + { + return m0.getBufferSizeInCells() + m1.getBufferSizeInCells() + m2.getBufferSizeInCells() + + m3.getBufferSizeInCells() + ms.getBufferSizeInCells(); + } size_t getBufferSizeInBytes() const { return m0.getBufferSizeInBytes() + m1.getBufferSizeInBytes() + m2.getBufferSizeInBytes() diff --git a/dbms/src/Common/HashTable/TwoLevelHashTable.h b/dbms/src/Common/HashTable/TwoLevelHashTable.h index 6778cd4a3e8..75a5402363d 100644 --- a/dbms/src/Common/HashTable/TwoLevelHashTable.h +++ b/dbms/src/Common/HashTable/TwoLevelHashTable.h @@ -60,6 +60,9 @@ class TwoLevelHashTable : private boost::noncopyable static constexpr size_t NUM_BUCKETS = 1ULL << BITS_FOR_BUCKET; static constexpr size_t MAX_BUCKET = NUM_BUCKETS - 1; + static constexpr bool is_string_hash_map = false; + static constexpr bool is_two_level = true; + size_t hash(const Key & x) const { return Hash::operator()(x); } /// NOTE Bad for hash tables with more than 2^32 cells. @@ -285,6 +288,12 @@ class TwoLevelHashTable : private boost::noncopyable impls[buck].emplace(key_holder, it, inserted, hash_value); } + void ALWAYS_INLINE prefetch(size_t hashval) const + { + size_t buck = getBucketFromHash(hashval); + impls[buck].prefetch(hashval); + } + LookupResult ALWAYS_INLINE find(Key x, size_t hash_value) { size_t buck = getBucketFromHash(hash_value); @@ -352,6 +361,13 @@ class TwoLevelHashTable : private boost::noncopyable return true; } + size_t getBufferSizeInCells() const + { + size_t res = 0; + for (const auto & impl : impls) + res += impl.getBufferSizeInCells(); + return res; + } size_t getBufferSizeInBytes() const { size_t res = 0; diff --git a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h index 5bdb24a3d13..7659b5a73fb 100644 --- a/dbms/src/Common/HashTable/TwoLevelStringHashTable.h +++ b/dbms/src/Common/HashTable/TwoLevelStringHashTable.h @@ -30,6 +30,9 @@ class TwoLevelStringHashTable : private boost::noncopyable static constexpr size_t NUM_BUCKETS = 1ULL << BITS_FOR_BUCKET; static constexpr size_t MAX_BUCKET = NUM_BUCKETS - 1; + static constexpr bool is_string_hash_map = true; + static constexpr bool is_two_level = true; + // TODO: currently hashing contains redundant computations when doing distributed or external aggregations size_t hash(const Key & x) const { @@ -104,7 +107,6 @@ class TwoLevelStringHashTable : private boost::noncopyable #endif dispatch(Self & self, KeyHolder && key_holder, Func && func) { - StringHashTableHash hash; const StringRef & x = keyHolderGetKey(key_holder); const size_t sz = x.size; if (sz == 0) @@ -117,7 +119,7 @@ class TwoLevelStringHashTable : private boost::noncopyable { // Strings with trailing zeros are not representable as fixed-size // string keys. Put them to the generic table. - auto res = hash(x); + auto res = SubMaps::Hash::operator()(x); auto buck = getBucketFromHash(res); return func(self.impls[buck].ms, std::forward(key_holder), res); } @@ -154,7 +156,7 @@ class TwoLevelStringHashTable : private boost::noncopyable else n[0] <<= s; } - auto res = hash(k8); + auto res = SubMaps::Hash::operator()(k8); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m1, k8, res); @@ -168,7 +170,7 @@ class TwoLevelStringHashTable : private boost::noncopyable n[1] >>= s; else n[1] <<= s; - auto res = hash(k16); + auto res = SubMaps::Hash::operator()(k16); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m2, k16, res); @@ -182,14 +184,14 @@ class TwoLevelStringHashTable : private boost::noncopyable n[2] >>= s; else n[2] <<= s; - auto res = hash(k24); + auto res = SubMaps::Hash::operator()(k24); auto buck = getBucketFromHash(res); keyHolderDiscardKey(key_holder); return func(self.impls[buck].m3, k24, res); } default: { - auto res = hash(x); + auto res = SubMaps::Hash::operator()(x); auto buck = getBucketFromHash(res); return func(self.impls[buck].ms, std::forward(key_holder), res); } @@ -202,9 +204,9 @@ class TwoLevelStringHashTable : private boost::noncopyable dispatch(*this, key_holder, typename Impl::EmplaceCallable{it, inserted}); } - LookupResult ALWAYS_INLINE find(const Key x) { return dispatch(*this, x, typename Impl::FindCallable{}); } + LookupResult ALWAYS_INLINE find(const Key & x) { return dispatch(*this, x, typename Impl::FindCallable{}); } - ConstLookupResult ALWAYS_INLINE find(const Key x) const + ConstLookupResult ALWAYS_INLINE find(const Key & x) const { return dispatch(*this, x, typename Impl::FindCallable{}); } @@ -259,6 +261,13 @@ class TwoLevelStringHashTable : private boost::noncopyable return true; } + size_t getBufferSizeInCells() const + { + size_t res = 0; + for (const auto & impl : impls) + res += impl.getBufferSizeInCells(); + return res; + } size_t getBufferSizeInBytes() const { size_t res = 0; diff --git a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp index 7193f24eddb..e407fcae764 100644 --- a/dbms/src/Flash/tests/gtest_aggregation_executor.cpp +++ b/dbms/src/Flash/tests/gtest_aggregation_executor.cpp @@ -24,6 +24,7 @@ namespace DB namespace FailPoints { extern const char force_agg_on_partial_block[]; +extern const char force_agg_prefetch[]; extern const char force_agg_two_level_hash_table_before_merge[]; } // namespace FailPoints namespace tests @@ -238,16 +239,22 @@ class AggExecutorTestRunner : public ExecutorTest ColumnWithUInt64 col_pr{1, 2, 0, 3290124, 968933, 3125, 31236, 4327, 80000}; }; -#define WRAP_FOR_AGG_PARTIAL_BLOCK_START \ - std::vector partial_blocks{true, false}; \ - for (auto partial_block : partial_blocks) \ - { \ - if (partial_block) \ - FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ - else \ - FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); +#define WRAP_FOR_AGG_FAILPOINTS_START \ + std::vector enables{true, false}; \ + for (auto enable : enables) \ + { \ + if (enable) \ + { \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); \ + } \ + else \ + { \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); \ + } -#define WRAP_FOR_AGG_PARTIAL_BLOCK_END } +#define WRAP_FOR_AGG_FAILPOINTS_END } /// Guarantee the correctness of group by TEST_F(AggExecutorTestRunner, GroupBy) @@ -363,9 +370,9 @@ try FailPointHelper::enableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); else FailPointHelper::disableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -429,9 +436,9 @@ try FailPointHelper::enableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); else FailPointHelper::disableFailPoint(FailPoints::force_agg_two_level_hash_table_before_merge); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -464,9 +471,9 @@ try for (size_t i = 0; i < test_num; ++i) { request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } /// Min function tests @@ -485,9 +492,9 @@ try for (size_t i = 0; i < test_num; ++i) { request = buildDAGRequest(std::make_pair(db_name, table_name), agg_funcs[i], group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -545,9 +552,9 @@ try { request = buildDAGRequest(std::make_pair(db_name, table_name), {agg_funcs[i]}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -615,9 +622,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } { @@ -629,9 +636,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } for (auto collation_id : {0, static_cast(TiDB::ITiDBCollator::BINARY)}) @@ -668,9 +675,9 @@ try {agg_func}, group_by_exprs[i], projections[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect_cols[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -683,9 +690,9 @@ try executeAndAssertColumnsEqual(request, {{toNullableVec({"banana"})}}); request = context.scan("aggnull_test", "t1").aggregation({}, {col("s1")}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {{toNullableVec("s1", {{}, "banana"})}}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } CATCH @@ -697,9 +704,9 @@ try = {toNullableVec({3}), toNullableVec({1}), toVec({6})}; auto test_single_function = [&](size_t index) { auto request = context.scan("test_db", "test_table").aggregation({functions[index]}, {}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {functions_result[index]}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END }; for (size_t i = 0; i < functions.size(); ++i) test_single_function(i); @@ -720,9 +727,9 @@ try results.push_back(functions_result[k]); auto request = context.scan("test_db", "test_table").aggregation(funcs, {}).build(context); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, results); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END funcs.pop_back(); results.pop_back(); @@ -758,9 +765,9 @@ try context.context->setSetting( "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, expect); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -791,7 +798,7 @@ try "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); context.context->setSetting("max_block_size", Field(static_cast(block_size))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); size_t actual_row = 0; for (auto & block : blocks) @@ -800,7 +807,7 @@ try actual_row += block.rows(); } ASSERT_EQ(actual_row, expect_rows[i]); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -914,7 +921,7 @@ try "group_by_two_level_threshold", Field(static_cast(two_level_threshold))); context.context->setSetting("max_block_size", Field(static_cast(block_size))); - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) { @@ -939,7 +946,7 @@ try vstackBlocks(std::move(blocks)).getColumnsWithTypeAndName(), false)); } - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } } @@ -967,18 +974,18 @@ try request = context.receive("empty_recv", 5).aggregation({Max(col("s1"))}, {col("s2")}, 5).build(context); { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } request = context.scan("test_db", "empty_table") .aggregation({Count(lit(Field(static_cast(1))))}, {}) .build(context); { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(request, {toVec({0})}); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH @@ -1035,6 +1042,31 @@ try toVec("col_tinyint", col_data_tinyint), }); + std::random_device rd; + std::mt19937_64 gen(rd()); + + std::vector max_block_sizes{1, 2, DEFAULT_BLOCK_SIZE}; + std::vector two_level_thresholds{0, 1}; + + std::uniform_int_distribution dist(0, max_block_sizes.size()); + size_t random_block_size = max_block_sizes[dist(gen)]; + + std::uniform_int_distribution dist1(0, two_level_thresholds.size()); + size_t random_two_level_threshold = two_level_thresholds[dist1(gen)]; + LOG_DEBUG( + Logger::get("AggExecutorTestRunner::AggKeyOptimization"), + "max_block_size: {}, two_level_threshold: {}", + random_block_size, + random_two_level_threshold); + + context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(0))); +#define WRAP_FOR_AGG_CHANGE_SETTINGS \ + context.context->setSetting( \ + "group_by_two_level_threshold", \ + Field(static_cast(random_two_level_threshold))); \ + context.context->setSetting("max_block_size", Field(static_cast(random_block_size))); + + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); { // case-1: select count(1), col_tinyint from t group by col_int, col_tinyint // agg method: keys64(AggregationMethodKeysFixed) @@ -1049,6 +1081,7 @@ try toNullableVec("first_row(col_tinyint)", ColumnWithNullableInt8{0, 1, 2, 3}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3}), toVec("col_tinyint", ColumnWithInt8{0, 1, 2, 3})}; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } @@ -1065,6 +1098,7 @@ try = {toVec("count(1)", ColumnWithUInt64{rows_per_type, rows_per_type, rows_per_type, rows_per_type}), toNullableVec("first_row(col_int)", ColumnWithNullableInt32{0, 1, 2, 3}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3})}; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } @@ -1099,6 +1133,7 @@ try toNullableVec("first_row(col_string_with_collator)", ColumnWithNullableString{"a", "b", "c", "d"}), toVec("col_string_with_collator", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } @@ -1116,6 +1151,7 @@ try toVec("count(1)", ColumnWithUInt64{rows_per_type, rows_per_type, rows_per_type, rows_per_type}), toVec("first_row(col_string_with_collator)", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } @@ -1138,6 +1174,7 @@ try toVec("col_int", ColumnWithInt32{0, 1, 2, 3}), toVec("col_string_no_collator", ColumnWithString{"a", "b", "c", "d"}), }; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } @@ -1155,8 +1192,11 @@ try toNullableVec("first_row(col_string_with_collator)", ColumnWithNullableString{"a", "b", "c", "d"}), toVec("col_string_with_collator", ColumnWithString{"a", "b", "c", "d"}), toVec("col_int", ColumnWithInt32{0, 1, 2, 3})}; + WRAP_FOR_AGG_CHANGE_SETTINGS executeAndAssertColumnsEqual(request, expected); } + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); +#undef WRAP_FOR_AGG_CHANGE_SETTINGS } CATCH @@ -1187,13 +1227,9 @@ try context .addExchangeReceiver("exchange_receiver_1_concurrency", column_infos, column_data, 1, partition_column_infos); - context - .addExchangeReceiver("exchange_receiver_3_concurrency", column_infos, column_data, 3, partition_column_infos); - context - .addExchangeReceiver("exchange_receiver_5_concurrency", column_infos, column_data, 5, partition_column_infos); context .addExchangeReceiver("exchange_receiver_10_concurrency", column_infos, column_data, 10, partition_column_infos); - std::vector exchange_receiver_concurrency = {1, 3, 5, 10}; + std::vector exchange_receiver_concurrency = {1, 10}; auto gen_request = [&](size_t exchange_concurrency) { return context @@ -1205,15 +1241,15 @@ try auto baseline = executeStreams(gen_request(1), 1); for (size_t exchange_concurrency : exchange_receiver_concurrency) { - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START executeAndAssertColumnsEqual(gen_request(exchange_concurrency), baseline); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END } } CATCH -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END +#undef WRAP_FOR_AGG_FAILPOINTS_START +#undef WRAP_FOR_AGG_FAILPOINTS_END } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index 69b2242df3d..ab1680e1697 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -39,6 +39,7 @@ extern const char exception_before_mpp_root_task_run[]; extern const char exception_during_mpp_non_root_task_run[]; extern const char exception_during_mpp_root_task_run[]; extern const char exception_during_query_run[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints namespace tests @@ -1843,6 +1844,7 @@ try auto_pass_through_test_data.nullable_high_ndv_tbl_name, auto_pass_through_test_data.nullable_medium_ndv_tbl_name, }; + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); for (const auto & tbl_name : workloads) { const String db_name = auto_pass_through_test_data.db_name; @@ -1868,6 +1870,7 @@ try res_no_pass_through); WRAP_FOR_SERVER_TEST_END } + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); } CATCH diff --git a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp index b19aaf03c4c..2f1df6404fd 100644 --- a/dbms/src/Flash/tests/gtest_spill_aggregation.cpp +++ b/dbms/src/Flash/tests/gtest_spill_aggregation.cpp @@ -23,6 +23,7 @@ namespace FailPoints { extern const char force_agg_on_partial_block[]; extern const char force_thread_0_no_agg_spill[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints namespace tests @@ -37,16 +38,22 @@ class SpillAggregationTestRunner : public DB::tests::ExecutorTest } }; -#define WRAP_FOR_AGG_PARTIAL_BLOCK_START \ - std::vector partial_blocks{true, false}; \ - for (auto partial_block : partial_blocks) \ - { \ - if (partial_block) \ - FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ - else \ - FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); +#define WRAP_FOR_AGG_FAILPOINTS_START \ + std::vector enables{true, false}; \ + for (auto enable : enables) \ + { \ + if (enable) \ + { \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::enableFailPoint(FailPoints::force_agg_prefetch); \ + } \ + else \ + { \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_on_partial_block); \ + FailPointHelper::disableFailPoint(FailPoints::force_agg_prefetch); \ + } -#define WRAP_FOR_AGG_PARTIAL_BLOCK_END } +#define WRAP_FOR_AGG_FAILPOINTS_END } #define WRAP_FOR_AGG_THREAD_0_NO_SPILL_START \ for (auto thread_0_no_spill : {true, false}) \ @@ -114,13 +121,13 @@ try context.context->setSetting("group_by_two_level_threshold_bytes", Field(static_cast(1))); /// don't use `executeAndAssertColumnsEqual` since it takes too long to run /// test single thread aggregation - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, 1)); /// test parallel aggregation ASSERT_COLUMNS_EQ_UR(ref_columns, executeStreams(request, original_max_streams)); WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END /// enable spill and use small max_cached_data_bytes_in_spiller context.context->setSetting("max_cached_data_bytes_in_spiller", Field(static_cast(total_data_size / 200))); /// test single thread aggregation @@ -262,7 +269,7 @@ try Field(static_cast(max_bytes_before_external_agg))); context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) @@ -289,7 +296,7 @@ try false)); } WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -417,7 +424,7 @@ try Field(static_cast(max_bytes_before_external_agg))); context.context->setSetting("max_block_size", Field(static_cast(max_block_size))); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START WRAP_FOR_AGG_THREAD_0_NO_SPILL_START auto blocks = getExecuteStreamsReturnBlocks(request, concurrency); for (auto & block : blocks) @@ -444,7 +451,7 @@ try false)); } WRAP_FOR_AGG_THREAD_0_NO_SPILL_END - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -518,9 +525,9 @@ try /// don't use `executeAndAssertColumnsEqual` since it takes too long to run auto request = gen_request(exchange_concurrency); WRAP_FOR_SPILL_TEST_BEGIN - WRAP_FOR_AGG_PARTIAL_BLOCK_START + WRAP_FOR_AGG_FAILPOINTS_START ASSERT_COLUMNS_EQ_UR(baseline, executeStreams(request, exchange_concurrency)); - WRAP_FOR_AGG_PARTIAL_BLOCK_END + WRAP_FOR_AGG_FAILPOINTS_END WRAP_FOR_SPILL_TEST_END } } @@ -528,8 +535,8 @@ CATCH #undef WRAP_FOR_SPILL_TEST_BEGIN #undef WRAP_FOR_SPILL_TEST_END -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_START -#undef WRAP_FOR_AGG_PARTIAL_BLOCK_END +#undef WRAP_FOR_AGG_FAILPOINTS_START +#undef WRAP_FOR_AGG_FAILPOINTS_END } // namespace tests } // namespace DB diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index f25c22717e8..9af4ed8d466 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -43,8 +43,12 @@ extern const char random_aggregate_create_state_failpoint[]; extern const char random_aggregate_merge_failpoint[]; extern const char force_agg_on_partial_block[]; extern const char random_fail_in_resize_callback[]; +extern const char force_agg_prefetch[]; } // namespace FailPoints +static constexpr size_t agg_prefetch_step = 16; +static constexpr size_t agg_mini_batch = 256; + #define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME #define AggregationMethodNameTwoLevel(NAME) AggregatedDataVariants::AggregationMethod_##NAME##_two_level #define AggregationMethodType(NAME) AggregatedDataVariants::Type::NAME @@ -665,7 +669,51 @@ void NO_INLINE Aggregator::executeImpl( { typename Method::State state(agg_process_info.key_columns, key_sizes, collators); - executeImplBatch(method, state, aggregates_pool, agg_process_info); + // 2MB as prefetch threshold, because normally server L2 cache is 1MB. + static constexpr size_t prefetch_threshold = (2 << 20); +#ifndef NDEBUG + bool disable_prefetch = (method.data.getBufferSizeInBytes() < prefetch_threshold); + fiu_do_on(FailPoints::force_agg_prefetch, { disable_prefetch = false; }); +#else + const bool disable_prefetch = (method.data.getBufferSizeInBytes() < prefetch_threshold); +#endif + + if constexpr (Method::State::is_serialized_key) + { + executeImplBatch(method, state, aggregates_pool, agg_process_info); + } + else if constexpr (Method::Data::is_string_hash_map) + { + // StringHashMap doesn't support prefetch. + executeImplBatch(method, state, aggregates_pool, agg_process_info); + } + else + { + if (disable_prefetch) + executeImplBatch(method, state, aggregates_pool, agg_process_info); + else + executeImplBatch(method, state, aggregates_pool, agg_process_info); + } +} + +template +std::optional::ResultType> Aggregator::emplaceOrFindKey( + Method & method, + typename Method::State & state, + typename Method::State::Derived::KeyHolderType && key_holder, + size_t hashval) const +{ + try + { + if constexpr (only_lookup) + return state.template findKey(method.data, std::move(key_holder), hashval); + else + return state.template emplaceKey(method.data, std::move(key_holder), hashval); + } + catch (ResizeException &) + { + return {}; + } } template @@ -689,7 +737,30 @@ std::optional::Res } } -template +template +ALWAYS_INLINE inline void prepareBatch( + size_t row_idx, + size_t end_row, + std::vector & hashvals, + std::vector & key_holders, + Arena * aggregates_pool, + std::vector & sort_key_containers, + Method & method, + typename Method::State & state) +{ + assert(hashvals.size() == key_holders.size()); + + for (size_t i = row_idx, j = 0; i < row_idx + hashvals.size() && i < end_row; ++i, ++j) + { + key_holders[j] = static_cast(&state)->getKeyHolder( + i, + aggregates_pool, + sort_key_containers); + hashvals[j] = method.data.hash(keyHolderGetKey(key_holders[j])); + } +} + +template ALWAYS_INLINE void Aggregator::executeImplBatch( Method & method, typename Method::State & state, @@ -699,63 +770,29 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( // collect_hit_rate and only_lookup cannot be true at the same time. static_assert(!(collect_hit_rate && only_lookup)); - std::vector sort_key_containers; - sort_key_containers.resize(params.keys_size, ""); - size_t agg_size = agg_process_info.end_row - agg_process_info.start_row; - fiu_do_on(FailPoints::force_agg_on_partial_block, { - if (agg_size > 0 && agg_process_info.start_row == 0) - agg_size = std::max(agg_size / 2, 1); - }); - /// Optimization for special case when there are no aggregate functions. if (params.aggregates_size == 0) - { - /// For all rows. - AggregateDataPtr place = aggregates_pool->alloc(0); - for (size_t i = 0; i < agg_size; ++i) - { - auto emplace_result_hold = emplaceOrFindKey( - method, - state, - agg_process_info.start_row, - *aggregates_pool, - sort_key_containers); - if likely (emplace_result_hold.has_value()) - { - if constexpr (collect_hit_rate) - { - ++agg_process_info.hit_row_cnt; - } - - if constexpr (only_lookup) - { - if (!emplace_result_hold.value().isFound()) - agg_process_info.not_found_rows.push_back(i); - } - else - { - emplace_result_hold.value().setMapped(place); - } - ++agg_process_info.start_row; - } - else - { - LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill"); - break; - } - } - return; - } + return handleOneBatch( + method, + state, + agg_process_info, + aggregates_pool); /// Optimization for special case when aggregating by 8bit key. if constexpr (std::is_same_v) { + size_t rows = agg_process_info.end_row - agg_process_info.start_row; + fiu_do_on(FailPoints::force_agg_on_partial_block, { + if (rows > 0 && agg_process_info.start_row == 0) + rows = std::max(rows / 2, 1); + }); + for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that; ++inst) { inst->batch_that->addBatchLookupTable8( agg_process_info.start_row, - agg_size, + rows, reinterpret_cast(method.data.data()), inst->state_offset, [&](AggregateDataPtr & aggregate_data) { @@ -767,12 +804,12 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( inst->batch_arguments, aggregates_pool); } - agg_process_info.start_row += agg_size; + agg_process_info.start_row += rows; // For key8, assume all rows are hit. No need to do state switch for auto pass through hashagg. // Because HashMap of key8 is basically a vector of size 256. if constexpr (collect_hit_rate) - agg_process_info.hit_row_cnt = agg_size; + agg_process_info.hit_row_cnt = rows; // Because all rows are hit, so state will not switch to Selective. if constexpr (only_lookup) @@ -781,77 +818,174 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( } /// Generic case. + return handleOneBatch( + method, + state, + agg_process_info, + aggregates_pool); +} + +template +void Aggregator::handleOneBatch( + Method & method, + typename Method::State & state, + AggProcessInfo & agg_process_info, + Arena * aggregates_pool) const +{ + std::vector sort_key_containers; + sort_key_containers.resize(params.keys_size, ""); + size_t rows = agg_process_info.end_row - agg_process_info.start_row; + fiu_do_on(FailPoints::force_agg_on_partial_block, { + if (rows > 0 && agg_process_info.start_row == 0) + rows = std::max(rows / 2, 1); + }); - std::unique_ptr places(new AggregateDataPtr[agg_size]); std::optional processed_rows; + std::optional::ResultType> emplace_result_holder; + + std::unique_ptr places{}; + // It's ok to use fake address, because no one use agg data if there is no agg func. + auto * place = reinterpret_cast(0x1); + if constexpr (compute_agg_data) + places = std::unique_ptr(new AggregateDataPtr[rows]); - for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i) + size_t i = agg_process_info.start_row; + const size_t end = agg_process_info.start_row + rows; + + size_t mini_batch_size = rows; + std::vector hashvals; + std::vector key_holders; + if constexpr (enable_prefetch) { - AggregateDataPtr aggregate_data = nullptr; + // mini batch will only be used when HashTable is big(a.k.a enable_prefetch is true), + // which can reduce cache miss of agg data. + mini_batch_size = agg_mini_batch; + hashvals.resize(agg_mini_batch); + key_holders.resize(agg_mini_batch); + } - auto emplace_result_holder - = emplaceOrFindKey(method, state, i, *aggregates_pool, sort_key_containers); - if unlikely (!emplace_result_holder.has_value()) + // i is the begin row index of each mini batch. + while (i < end) + { + if constexpr (enable_prefetch) { - LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill"); - break; - } + if unlikely (i + mini_batch_size > end) + mini_batch_size = end - i; - auto & emplace_result = emplace_result_holder.value(); + prepareBatch(i, end, hashvals, key_holders, aggregates_pool, sort_key_containers, method, state); + } - if constexpr (only_lookup) + const auto cur_batch_end = i + mini_batch_size; + // j is the row index of Column. + // k is the index of hashvals/key_holders. + for (size_t j = i, k = 0; j < cur_batch_end; ++j, ++k) { - if (emplace_result.isFound()) + AggregateDataPtr aggregate_data = nullptr; + const size_t index_relative_to_start_row = j - agg_process_info.start_row; + if constexpr (enable_prefetch) { - aggregate_data = emplace_result.getMapped(); + if likely (k + agg_prefetch_step < hashvals.size()) + method.data.prefetch(hashvals[k + agg_prefetch_step]); + + emplace_result_holder + = emplaceOrFindKey(method, state, std::move(key_holders[k]), hashvals[k]); } else { - agg_process_info.not_found_rows.push_back(i); + emplace_result_holder + = emplaceOrFindKey(method, state, j, *aggregates_pool, sort_key_containers); } - } - else - { - /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key. - if (emplace_result.isInserted()) - { - /// exception-safety - if you can not allocate memory or create states, then destructors will not be called. - emplace_result.setMapped(nullptr); - aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); - createAggregateStates(aggregate_data); + if unlikely (!emplace_result_holder.has_value()) + { + LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill"); + break; + } - emplace_result.setMapped(aggregate_data); + auto & emplace_result = emplace_result_holder.value(); + if constexpr (only_lookup) + { + if constexpr (compute_agg_data) + { + if (emplace_result.isFound()) + { + aggregate_data = emplace_result.getMapped(); + } + else + { + agg_process_info.not_found_rows.push_back(j); + } + } + else + { + if (!emplace_result.isFound()) + agg_process_info.not_found_rows.push_back(i); + } } else { - aggregate_data = emplace_result.getMapped(); + if constexpr (compute_agg_data) + { + if (emplace_result.isInserted()) + { + // exception-safety - if you can not allocate memory or create states, then destructors will not be called. + emplace_result.setMapped(nullptr); + + aggregate_data + = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); + createAggregateStates(aggregate_data); + + emplace_result.setMapped(aggregate_data); + } + else + { + aggregate_data = emplace_result.getMapped(); - if constexpr (collect_hit_rate) - ++agg_process_info.hit_row_cnt; + if constexpr (collect_hit_rate) + ++agg_process_info.hit_row_cnt; + + if constexpr (enable_prefetch) + __builtin_prefetch(aggregate_data); + } + } + else + { + emplace_result.setMapped(place); + } } + if constexpr (compute_agg_data) + places[index_relative_to_start_row] = aggregate_data; + processed_rows = j; } - places[i - agg_process_info.start_row] = aggregate_data; - processed_rows = i; - } + if unlikely (!processed_rows.has_value()) + break; - if (processed_rows) - { - /// Add values to the aggregate functions. - for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that; - ++inst) + const size_t processed_size = *processed_rows - i + 1; + if constexpr (compute_agg_data) { - inst->batch_that->addBatch( - agg_process_info.start_row, - *processed_rows - agg_process_info.start_row + 1, - places.get(), - inst->state_offset, - inst->batch_arguments, - aggregates_pool); + for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); + inst->that; + ++inst) + { + inst->batch_that->addBatch( + i, + processed_size, + places.get() + i - agg_process_info.start_row, + inst->state_offset, + inst->batch_arguments, + aggregates_pool); + } } - agg_process_info.start_row = *processed_rows + 1; + + if unlikely (processed_size != mini_batch_size) + break; + + i = cur_batch_end; } + + if likely (processed_rows) + agg_process_info.start_row = *processed_rows + 1; } void NO_INLINE @@ -876,7 +1010,6 @@ Aggregator::executeWithoutKeyImpl(AggregatedDataWithoutKey & res, AggProcessInfo agg_process_info.start_row += agg_size; } - void Aggregator::prepareAggregateInstructions( Columns columns, AggregateColumns & aggregate_columns, @@ -1610,6 +1743,7 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( agg_keys_helper.initAggKeys(data.size(), key_columns); } + // Doesn't prefetch agg data, because places[data.size()] is needed, which can be very large. data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { @@ -1687,16 +1821,45 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( } size_t data_index = 0; + const auto rows = data.size(); + std::unique_ptr places(new AggregateDataPtr[rows]); + + size_t current_bound = params.max_block_size; + size_t key_columns_vec_index = 0; + data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { - size_t key_columns_vec_index = data_index / params.max_block_size; if constexpr (!skip_convert_key) { agg_keys_helpers[key_columns_vec_index] ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); } - insertAggregatesIntoColumns(mapped, final_aggregate_columns_vec[key_columns_vec_index], arena); + places[data_index] = mapped; ++data_index; + + if unlikely (data_index == current_bound) + { + ++key_columns_vec_index; + current_bound += params.max_block_size; + } }); + + data_index = 0; + current_bound = params.max_block_size; + key_columns_vec_index = 0; + while (data_index < rows) + { + if likely (data_index + agg_prefetch_step < rows) + __builtin_prefetch(places[data_index + agg_prefetch_step]); + + insertAggregatesIntoColumns(places[data_index], final_aggregate_columns_vec[key_columns_vec_index], arena); + ++data_index; + + if unlikely (data_index == current_bound) + { + ++key_columns_vec_index; + current_bound += params.max_block_size; + } + } } template diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 381bfba8462..b9c1c18f484 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -231,8 +231,7 @@ struct AggregationMethodStringNoCache : data(other.data) {} - using State = ColumnsHashing:: - HashMethodString; + using State = ColumnsHashing::HashMethodString; template struct EmplaceOrFindKeyResult { @@ -528,7 +527,7 @@ struct AggregationMethodFixedStringNoCache : data(other.data) {} - using State = ColumnsHashing::HashMethodFixedString; + using State = ColumnsHashing::HashMethodFixedString; template struct EmplaceOrFindKeyResult { @@ -1454,13 +1453,27 @@ class Aggregator AggProcessInfo & agg_process_info, TiDB::TiDBCollators & collators) const; - template + template void executeImplBatch( Method & method, typename Method::State & state, Arena * aggregates_pool, AggProcessInfo & agg_process_info) const; + template + void handleOneBatch( + Method & method, + typename Method::State & state, + AggProcessInfo & agg_process_info, + Arena * aggregates_pool) const; + + template + std::optional::ResultType> emplaceOrFindKey( + Method & method, + typename Method::State & state, + typename Method::State::Derived::KeyHolderType && key_holder, + size_t hashval) const; + template std::optional::ResultType> emplaceOrFindKey( Method & method, diff --git a/dbms/src/Interpreters/JoinPartition.cpp b/dbms/src/Interpreters/JoinPartition.cpp index a060878c4f7..294c72c19a3 100644 --- a/dbms/src/Interpreters/JoinPartition.cpp +++ b/dbms/src/Interpreters/JoinPartition.cpp @@ -412,7 +412,7 @@ struct KeyGetterForTypeImpl template struct KeyGetterForTypeImpl { - using Type = ColumnsHashing::HashMethodString; + using Type = ColumnsHashing::HashMethodString; }; template struct KeyGetterForTypeImpl @@ -427,7 +427,7 @@ struct KeyGetterForTypeImpl template struct KeyGetterForTypeImpl { - using Type = ColumnsHashing::HashMethodFixedString; + using Type = ColumnsHashing::HashMethodFixedString; }; template struct KeyGetterForTypeImpl @@ -652,18 +652,18 @@ void NO_INLINE insertBlockIntoMapsTypeCase( insert_indexes.emplace_back(insert_index); } -#define INSERT_TO_MAP(join_partition, segment_index) \ - auto & current_map = (join_partition)->getHashMap(); \ - for (auto & s_i : (segment_index)) \ - { \ - Inserter::insert( \ - current_map, \ - key_getter, \ - stored_block, \ - s_i, \ - pool, \ - sort_key_containers, \ - probe_cache_column_threshold); \ +#define INSERT_TO_MAP(join_partition, segment_index) \ + auto & current_map = (join_partition) -> getHashMap(); \ + for (auto & s_i : (segment_index)) \ + { \ + Inserter::insert( \ + current_map, \ + key_getter, \ + stored_block, \ + s_i, \ + pool, \ + sort_key_containers, \ + probe_cache_column_threshold); \ } #define INSERT_TO_NOT_INSERTED_MAP \ diff --git a/dbms/src/Interpreters/SetVariants.h b/dbms/src/Interpreters/SetVariants.h index a1591f8c13a..5c503240b7b 100644 --- a/dbms/src/Interpreters/SetVariants.h +++ b/dbms/src/Interpreters/SetVariants.h @@ -54,7 +54,7 @@ struct SetMethodString Data data; - using State = ColumnsHashing::HashMethodString; + using State = ColumnsHashing::HashMethodString; }; template @@ -77,7 +77,7 @@ struct SetMethodFixedString Data data; - using State = ColumnsHashing::HashMethodFixedString; + using State = ColumnsHashing::HashMethodFixedString; }; namespace set_impl diff --git a/libs/libcommon/include/common/StringRef.h b/libs/libcommon/include/common/StringRef.h index a87b54a7670..bf1ab026a49 100644 --- a/libs/libcommon/include/common/StringRef.h +++ b/libs/libcommon/include/common/StringRef.h @@ -180,7 +180,7 @@ inline size_t hashLessThan8(const char * data, size_t size) struct CRC32Hash { - size_t operator()(StringRef x) const + static size_t operator()(const StringRef & x) { const char * pos = x.data; size_t size = x.size;