-
Notifications
You must be signed in to change notification settings - Fork 821
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Test that shrink_to_fit actually frees memory
- Loading branch information
Showing
1 changed file
with
134 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
use arrow::{ | ||
array::{Array, ArrayRef, ListArray, PrimitiveArray}, | ||
buffer::OffsetBuffer, | ||
datatypes::{Field, UInt8Type}, | ||
}; | ||
|
||
/// Test that `shrink_to_fit` frees memory after concatenating a large number of arrays.
///
/// The test wraps `array_len` `u8` values in a single-row `ListArray`, concatenates
/// `num_concats` copies of it one at a time, calls `shrink_to_fit` on the result and then
/// checks that the number of bytes still live on this thread is at least the payload size
/// and at most 1% above it.
#[test]
fn test_shrink_to_fit_after_concat() {
    let array_len = 6_000;
    let num_concats = 100;

    let primitive_array: PrimitiveArray<UInt8Type> = (0..array_len)
        .map(|v| (v % 255) as u8)
        .collect::<Vec<_>>()
        .into();
    let primitive_array: ArrayRef = Arc::new(primitive_array);

    let list_array: ArrayRef = Arc::new(ListArray::new(
        Field::new_list_field(primitive_array.data_type().clone(), false).into(),
        OffsetBuffer::from_lengths([primitive_array.len()]),
        primitive_array.clone(),
        None,
    ));

    // Num bytes allocated globally and by this thread, respectively.
    let (concatenated, _bytes_allocated_globally, bytes_allocated_by_this_thread) =
        memory_use(|| {
            let mut concatenated = concatenate(num_concats, list_array.clone());
            concatenated.shrink_to_fit(); // This is what we're testing!
            concatenated
        });
    // Diagnostic output is emitted only after the measurement: formatting and printing
    // may allocate on this thread (e.g. into a test-harness capture buffer), and any such
    // allocation made inside the closure would be counted as array overhead.
    dbg!(concatenated.data_type());

    let expected_len = num_concats * array_len;
    assert_eq!(bytes_used(concatenated.clone()), expected_len);
    eprintln!("The concatenated array is {expected_len} B long. Amount of memory used by this thread: {bytes_allocated_by_this_thread} B");

    assert!(
        expected_len <= bytes_allocated_by_this_thread,
        "We must allocate at least as much space as the concatenated array"
    );
    assert!(
        bytes_allocated_by_this_thread <= expected_len + expected_len / 100,
        "We shouldn't have more than 1% memory overhead. In fact, we are using {bytes_allocated_by_this_thread}B of memory for {expected_len}B of data"
    );
}
|
||
fn concatenate(num_times: usize, array: ArrayRef) -> ArrayRef { | ||
let mut concatenated = array.clone(); | ||
for _ in 0..num_times - 1 { | ||
concatenated = arrow::compute::kernels::concat::concat(&[&*concatenated, &*array]).unwrap(); | ||
} | ||
concatenated | ||
} | ||
|
||
fn bytes_used(array: ArrayRef) -> usize { | ||
let mut array = array; | ||
loop { | ||
match array.data_type() { | ||
arrow::datatypes::DataType::UInt8 => break, | ||
arrow::datatypes::DataType::List(_) => { | ||
let list = array.as_any().downcast_ref::<ListArray>().unwrap(); | ||
array = list.values().clone(); | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
|
||
array.len() | ||
} | ||
|
||
// --- Memory tracking --- | ||
|
||
use std::sync::{ | ||
atomic::{AtomicUsize, Ordering::Relaxed}, | ||
Arc, | ||
}; | ||
|
||
/// Bytes handed out by [`TrackingAllocator`] and not yet returned, summed over all threads.
static LIVE_BYTES_GLOBAL: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    /// Per-thread counter: increased by every allocation made on this thread and
    /// decreased by every deallocation performed on this thread.
    static LIVE_BYTES_IN_THREAD: AtomicUsize = const { AtomicUsize::new(0) };
}

/// Allocator that updates the live-byte counters and then forwards every request to the
/// system allocator it wraps.
pub struct TrackingAllocator {
    allocator: std::alloc::System,
}

/// The [`TrackingAllocator`] instance installed as this binary's global allocator.
#[global_allocator]
pub static GLOBAL_ALLOCATOR: TrackingAllocator = TrackingAllocator {
    allocator: std::alloc::System,
};

#[allow(unsafe_code)]
// SAFETY:
// Only the counters are touched here; the layout and pointer are passed through unchanged
// to the wrapped allocator, which does all the actual work.
unsafe impl std::alloc::GlobalAlloc for TrackingAllocator {
    #[allow(clippy::let_and_return)]
    unsafe fn alloc(&self, layout: std::alloc::Layout) -> *mut u8 {
        // Record the requested size before delegating.
        let size = layout.size();
        LIVE_BYTES_IN_THREAD.with(|counter| counter.fetch_add(size, Relaxed));
        LIVE_BYTES_GLOBAL.fetch_add(size, Relaxed);

        // SAFETY:
        // The caller's contract is forwarded as-is to the inner allocator.
        unsafe { self.allocator.alloc(layout) }
    }

    unsafe fn dealloc(&self, ptr: *mut u8, layout: std::alloc::Layout) {
        // Mirror of `alloc`: give the bytes back to both counters, then free.
        let size = layout.size();
        LIVE_BYTES_IN_THREAD.with(|counter| counter.fetch_sub(size, Relaxed));
        LIVE_BYTES_GLOBAL.fetch_sub(size, Relaxed);

        // SAFETY:
        // The caller's contract is forwarded as-is to the inner allocator.
        unsafe { self.allocator.dealloc(ptr, layout) };
    }
}
|
||
fn live_bytes_local() -> usize { | ||
LIVE_BYTES_IN_THREAD.with(|bytes| bytes.load(Relaxed)) | ||
} | ||
|
||
fn live_bytes_global() -> usize { | ||
LIVE_BYTES_GLOBAL.load(Relaxed) | ||
} | ||
|
||
/// Returns `(num_bytes_allocated, num_bytes_allocated_by_this_thread)`. | ||
fn memory_use<R>(run: impl Fn() -> R) -> (R, usize, usize) { | ||
let used_bytes_start_local = live_bytes_local(); | ||
let used_bytes_start_global = live_bytes_global(); | ||
let ret = run(); | ||
let bytes_used_local = live_bytes_local() - used_bytes_start_local; | ||
let bytes_used_global = live_bytes_global() - used_bytes_start_global; | ||
(ret, bytes_used_global, bytes_used_local) | ||
} |