From aeb7fe483edbc8d3d314794a195274a93230dccf Mon Sep 17 00:00:00 2001
From: Emil Ernerfeldt
Date: Tue, 26 Nov 2024 10:23:36 +0100
Subject: [PATCH] Test that shrink_to_fit actually frees memory

---
 arrow/tests/shrink_to_fit.rs | 134 ++++++++++++++++++++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 arrow/tests/shrink_to_fit.rs

diff --git a/arrow/tests/shrink_to_fit.rs b/arrow/tests/shrink_to_fit.rs
new file mode 100644
index 00000000000..51901cad8c3
--- /dev/null
+++ b/arrow/tests/shrink_to_fit.rs
@@ -0,0 +1,134 @@
use arrow::{
    array::{Array, ArrayRef, ListArray, PrimitiveArray},
    buffer::OffsetBuffer,
    datatypes::{Field, UInt8Type},
};

/// Test that `shrink_to_fit` frees memory after concatenating a large number of arrays.
#[test]
fn test_shrink_to_fit_after_concat() {
    let array_len = 6_000;
    let num_concats = 100;

    let primitive_array: PrimitiveArray<UInt8Type> = (0..array_len)
        .map(|v| (v % 255) as u8)
        .collect::<Vec<_>>()
        .into();
    let primitive_array: ArrayRef = Arc::new(primitive_array);

    let list_array: ArrayRef = Arc::new(ListArray::new(
        Field::new_list_field(primitive_array.data_type().clone(), false).into(),
        OffsetBuffer::from_lengths([primitive_array.len()]),
        primitive_array.clone(),
        None,
    ));

    // Num bytes allocated globally and by this thread, respectively.
    let (concatenated, _bytes_allocated_globally, bytes_allocated_by_this_thread) =
        memory_use(|| {
            let mut concatenated = concatenate(num_concats, list_array.clone());
            concatenated.shrink_to_fit(); // This is what we're testing!
            dbg!(concatenated.data_type());
            concatenated
        });
    let expected_len = num_concats * array_len;
    assert_eq!(bytes_used(concatenated.clone()), expected_len);
    eprintln!("The concatenated array is {expected_len} B long. Amount of memory used by this thread: {bytes_allocated_by_this_thread} B");

    assert!(
        expected_len <= bytes_allocated_by_this_thread,
        "We must allocate at least as much space as the concatenated array"
    );
    assert!(
        bytes_allocated_by_this_thread <= expected_len + expected_len / 100,
        "We shouldn't have more than 1% memory overhead. In fact, we are using {bytes_allocated_by_this_thread} B of memory for {expected_len} B of data"
    );
}

fn concatenate(num_times: usize, array: ArrayRef) -> ArrayRef {
    let mut concatenated = array.clone();
    for _ in 0..num_times - 1 {
        concatenated = arrow::compute::kernels::concat::concat(&[&*concatenated, &*array]).unwrap();
    }
    concatenated
}

fn bytes_used(array: ArrayRef) -> usize {
    // Walk down through the list array to the underlying `UInt8` values,
    // whose length in elements equals their size in bytes.
    let mut array = array;
    loop {
        match array.data_type() {
            arrow::datatypes::DataType::UInt8 => break,
            arrow::datatypes::DataType::List(_) => {
                let list = array.as_any().downcast_ref::<ListArray>().unwrap();
                array = list.values().clone();
            }
            _ => unreachable!(),
        }
    }

    array.len()
}

// --- Memory tracking ---

use std::sync::{
    atomic::{AtomicUsize, Ordering::Relaxed},
    Arc,
};

static LIVE_BYTES_GLOBAL: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    static LIVE_BYTES_IN_THREAD: AtomicUsize = const { AtomicUsize::new(0) };
}

pub struct TrackingAllocator {
    allocator: std::alloc::System,
}

#[global_allocator]
pub static GLOBAL_ALLOCATOR: TrackingAllocator = TrackingAllocator {
    allocator: std::alloc::System,
};

#[allow(unsafe_code)]
// SAFETY:
// We just do book-keeping and then let another allocator do all the actual work.
unsafe impl std::alloc::GlobalAlloc for TrackingAllocator {
    #[allow(clippy::let_and_return)]
    unsafe fn alloc(&self, layout: std::alloc::Layout) -> *mut u8 {
        // Record the allocation in both the per-thread and the global counters.
        LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_add(layout.size(), Relaxed));
        LIVE_BYTES_GLOBAL.fetch_add(layout.size(), Relaxed);

        // SAFETY:
        // Just deferring
        unsafe { self.allocator.alloc(layout) }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: std::alloc::Layout) {
        LIVE_BYTES_IN_THREAD.with(|bytes| bytes.fetch_sub(layout.size(), Relaxed));
        LIVE_BYTES_GLOBAL.fetch_sub(layout.size(), Relaxed);

        // SAFETY:
        // Just deferring
        unsafe { self.allocator.dealloc(ptr, layout) };
    }
}

fn live_bytes_local() -> usize {
    LIVE_BYTES_IN_THREAD.with(|bytes| bytes.load(Relaxed))
}

fn live_bytes_global() -> usize {
    LIVE_BYTES_GLOBAL.load(Relaxed)
}

/// Runs `run`, returning `(ret, num_bytes_allocated_globally, num_bytes_allocated_by_this_thread)`.
fn memory_use<R>(run: impl Fn() -> R) -> (R, usize, usize) {
    let used_bytes_start_local = live_bytes_local();
    let used_bytes_start_global = live_bytes_global();
    let ret = run();
    let bytes_used_local = live_bytes_local() - used_bytes_start_local;
    let bytes_used_global = live_bytes_global() - used_bytes_start_global;
    (ret, bytes_used_global, bytes_used_local)
}
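
Note on the 1% bound asserted above: repeated `concat` calls grow the backing buffers geometrically, which leaves slack capacity that only `shrink_to_fit` returns to the allocator. Below is a minimal, std-only sketch of that effect, using a plain `Vec<u8>` as a stand-in for arrow's buffers; the 6,000 x 100 sizes mirror the test, but this is an analogy, not arrow's actual growth policy.

fn main() {
    let mut buf: Vec<u8> = Vec::new();
    for _ in 0..100 {
        // Append 6,000 bytes at a time, like each concatenation step in the test.
        buf.extend(std::iter::repeat(0u8).take(6_000));
    }
    assert_eq!(buf.len(), 600_000);
    // Geometric growth typically leaves capacity well above len at this point...
    eprintln!("before shrink_to_fit: len={} capacity={}", buf.len(), buf.capacity());
    buf.shrink_to_fit();
    // ...and shrink_to_fit releases the slack back to the allocator.
    eprintln!("after shrink_to_fit:  len={} capacity={}", buf.len(), buf.capacity());
}

Here `Vec::capacity` exposes the slack directly; the patch instead measures live bytes through the tracking allocator, checking the same before/after contrast at the allocator level.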