From 207e855f235401b997269fa858af641cf2a7b81e Mon Sep 17 00:00:00 2001
From: Marco Neumann <marco@crepererum.net>
Date: Fri, 22 Nov 2024 04:23:15 +0100
Subject: [PATCH] refactor: change some `hashbrown` `RawTable` uses to
 `HashTable` (#13514)

* feat: add `HashTableAllocExt`

This is similar to `RawTableAllocExt` and will help #13256.

* refactor: convert `ArrowBytesMap` to `HashTable`

For #13256.

* refactor: convert `ArrowBytesViewMap` to `HashTable`

For #13256.
---
 datafusion/common/src/utils/proxy.rs          | 73 ++++++++++++++++++-
 datafusion/execution/src/memory_pool/mod.rs   |  4 +-
 .../physical-expr-common/src/binary_map.rs    | 10 +--
 .../src/binary_view_map.rs                    |  8 +-
 4 files changed, 84 insertions(+), 11 deletions(-)

diff --git a/datafusion/common/src/utils/proxy.rs b/datafusion/common/src/utils/proxy.rs
index 5d14a1517129..b32164f682fa 100644
--- a/datafusion/common/src/utils/proxy.rs
+++ b/datafusion/common/src/utils/proxy.rs
@@ -17,7 +17,10 @@
 
 //! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations
 
-use hashbrown::raw::{Bucket, RawTable};
+use hashbrown::{
+    hash_table::HashTable,
+    raw::{Bucket, RawTable},
+};
 use std::mem::size_of;
 
 /// Extension trait for [`Vec`] to account for allocations.
@@ -173,3 +176,71 @@ impl<T> RawTableAllocExt for RawTable<T> {
         }
     }
 }
+
+/// Extension trait for hashbrown's [`HashTable`] to account for allocations.
+pub trait HashTableAllocExt {
+    /// Item type.
+    type T;
+
+    /// Insert a new element into the table and increase
+    /// `accounting` by any newly allocated bytes.
+    ///
+    /// This method does not return anything (no bucket or entry handle).
+    /// Note that allocation counts capacity, not size.
+    ///
+    /// # Example:
+    /// ```
+    /// # use datafusion_common::utils::proxy::HashTableAllocExt;
+    /// # use hashbrown::hash_table::HashTable;
+    /// let mut table = HashTable::new();
+    /// let mut allocated = 0;
+    /// let hash_fn = |x: &u32| (*x as u64) % 1000;
+    /// // `hash_fn` is a toy hash: the value itself modulo 1000
+    /// table.insert_accounted(1, hash_fn, &mut allocated);
+    /// assert_eq!(allocated, 64);
+    ///
+    /// // insert more values (1 is already present and is not inserted again)
+    /// for i in 0..100 { table.insert_accounted(i, hash_fn, &mut allocated); }
+    /// assert_eq!(allocated, 400);
+    /// ```
+    fn insert_accounted(
+        &mut self,
+        x: Self::T,
+        hasher: impl Fn(&Self::T) -> u64,
+        accounting: &mut usize,
+    );
+}
+
+impl<T> HashTableAllocExt for HashTable<T>
+where
+    T: Eq,
+{
+    type T = T;
+
+    fn insert_accounted(
+        &mut self,
+        x: Self::T,
+        hasher: impl Fn(&Self::T) -> u64,
+        accounting: &mut usize,
+    ) {
+        let hash = hasher(&x);
+
+        // NOTE: `find_entry` does NOT grow, so growth is accounted for explicitly below
+        match self.find_entry(hash, |y| y == &x) {
+            Ok(_occupied) => {} // equal element already stored: nothing to insert
+            Err(_absent) => {
+                if self.len() == self.capacity() {
+                    // table is full: charge and reserve `max(capacity, 16)` more slots
+                    let bump_elements = self.capacity().max(16);
+                    let bump_size = bump_elements * size_of::<T>();
+                    *accounting = (*accounting).checked_add(bump_size).expect("overflow");
+
+                    self.reserve(bump_elements, &hasher);
+                }
+
+                // `x` is known to be absent, so insert it without a second lookup
+                self.insert_unique(hash, x, hasher);
+            }
+        }
+    }
+}
diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs
index 5bf30b724d0b..45d467f133bf 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -23,7 +23,9 @@ use std::{cmp::Ordering, sync::Arc};
 
 mod pool;
 pub mod proxy {
-    pub use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
+    pub use datafusion_common::utils::proxy::{
+        HashTableAllocExt, RawTableAllocExt, VecAllocExt,
+    };
 }
 
 pub use pool::*;
diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs
index 59280a3abbdb..8febbdd5b1f9 100644
--- a/datafusion/physical-expr-common/src/binary_map.rs
+++ b/datafusion/physical-expr-common/src/binary_map.rs
@@ -28,7 +28,7 @@ use arrow::array::{
 use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
 use arrow::datatypes::DataType;
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
 use std::any::type_name;
 use std::fmt::Debug;
 use std::mem::{size_of, swap};
@@ -215,7 +215,7 @@ where
     /// Should the output be String or Binary?
     output_type: OutputType,
     /// Underlying hash set for each distinct value
-    map: hashbrown::raw::RawTable<Entry<O, V>>,
+    map: hashbrown::hash_table::HashTable<Entry<O, V>>,
     /// Total size of the map in bytes
     map_size: usize,
     /// In progress arrow `Buffer` containing all values
@@ -246,7 +246,7 @@ where
     pub fn new(output_type: OutputType) -> Self {
         Self {
             output_type,
-            map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
+            map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
             map_size: 0,
             buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
             offsets: vec![O::default()], // first offset is always 0
@@ -387,7 +387,7 @@ where
                 let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
 
                 // is value is already present in the set?
-                let entry = self.map.get_mut(hash, |header| {
+                let entry = self.map.find_mut(hash, |header| {
                     // compare value if hashes match
                     if header.len != value_len {
                         return false;
@@ -425,7 +425,7 @@ where
             // value is not "small"
             else {
                 // Check if the value is already present in the set
-                let entry = self.map.get_mut(hash, |header| {
+                let entry = self.map.find_mut(hash, |header| {
                     // compare value if hashes match
                     if header.len != value_len {
                         return false;
diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs
index 8af35510dd6c..4148c5ffa7c7 100644
--- a/datafusion/physical-expr-common/src/binary_view_map.rs
+++ b/datafusion/physical-expr-common/src/binary_view_map.rs
@@ -24,7 +24,7 @@ use arrow::array::cast::AsArray;
 use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
 use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt};
 use std::fmt::Debug;
 use std::sync::Arc;
 
@@ -122,7 +122,7 @@ where
     /// Should the output be StringView or BinaryView?
     output_type: OutputType,
     /// Underlying hash set for each distinct value
-    map: hashbrown::raw::RawTable<Entry<V>>,
+    map: hashbrown::hash_table::HashTable<Entry<V>>,
     /// Total size of the map in bytes
     map_size: usize,
 
@@ -148,7 +148,7 @@ where
     pub fn new(output_type: OutputType) -> Self {
         Self {
             output_type,
-            map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
+            map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY),
             map_size: 0,
             builder: GenericByteViewBuilder::new(),
             random_state: RandomState::new(),
@@ -274,7 +274,7 @@ where
             // get the value as bytes
             let value: &[u8] = value.as_ref();
 
-            let entry = self.map.get_mut(hash, |header| {
+            let entry = self.map.find_mut(hash, |header| {
                 let v = self.builder.get_value(header.view_idx);
 
                 if v.len() != value.len() {