Skip to content

Commit

Permalink
fix: use T aligned pointer in TempFdArray
Browse files Browse the repository at this point in the history
  • Loading branch information
Erigara committed Jan 31, 2024
1 parent 4939f73 commit 55fa419
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
60 changes: 44 additions & 16 deletions src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::convert::TryInto;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::io::{Read, Seek, SeekFrom, Write};
use std::mem::ManuallyDrop;

use crate::frames::UnresolvedFrames;

Expand Down Expand Up @@ -148,6 +149,7 @@ pub struct TempFdArray<T: 'static> {
file: NamedTempFile,
buffer: Box<[T; BUFFER_LENGTH]>,
buffer_index: usize,
flush_n: usize,
}

impl<T: Default + Debug> TempFdArray<T> {
Expand All @@ -162,6 +164,7 @@ impl<T: Default + Debug> TempFdArray<T> {
file,
buffer,
buffer_index: 0,
flush_n: 0,
})
}
}
Expand All @@ -175,6 +178,7 @@ impl<T> TempFdArray<T> {
BUFFER_LENGTH * std::mem::size_of::<T>(),
)
};
self.flush_n += 1;
self.file.write_all(buf)?;

Ok(())
Expand All @@ -191,24 +195,50 @@ impl<T> TempFdArray<T> {
Ok(())
}

fn try_iter(&self) -> std::io::Result<impl Iterator<Item = &T>> {
let mut file_vec = Vec::new();
let mut file = self.file.reopen()?;
file.seek(SeekFrom::Start(0))?;
file.read_to_end(&mut file_vec)?;
file.seek(SeekFrom::End(0))?;
fn try_iter<'lt>(&'lt self, file_buffer_container: &'lt mut Option<Box<[ManuallyDrop<T>]>>) -> std::io::Result<impl Iterator<Item = &'lt T>> {
let file_buffer = self.file_buffer()?;
let file_buffer = file_buffer_container.insert(file_buffer);

Ok(TempFdArrayIterator {
buffer: &self.buffer[0..self.buffer_index],
file_vec,
file_buffer,
index: 0,
})
}

fn file_buffer(&self) -> std::io::Result<Box<[ManuallyDrop<T>]>> {
if self.flush_n == 0 {
return Ok(Vec::new().into_boxed_slice())
}

let mut file = self.file.reopen()?;
file.seek(SeekFrom::Start(0))?;
let file_buffer = unsafe {
// Get properly aligned pointer
let len = BUFFER_LENGTH * self.flush_n;
// Expect T to be non-ZST
let layout = std::alloc::Layout::array::<ManuallyDrop<T>>(len).unwrap();
let ptr = std::alloc::alloc(layout);
if ptr.is_null() {
std::alloc::handle_alloc_error(layout);
}
// Populate with bytes
file.read_exact(std::slice::from_raw_parts_mut(
ptr,
len * std::mem::size_of::<T>(),
))?;
// Cast to proper type
Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr.cast::<ManuallyDrop<T>>(), len))
};
file.seek(SeekFrom::End(0))?;

Ok(file_buffer)
}
}

pub struct TempFdArrayIterator<'a, T> {
pub buffer: &'a [T],
pub file_vec: Vec<u8>,
pub file_buffer: &'a [ManuallyDrop<T>],
pub index: usize,
}

Expand All @@ -220,12 +250,9 @@ impl<'a, T> Iterator for TempFdArrayIterator<'a, T> {
self.index += 1;
Some(&self.buffer[self.index - 1])
} else {
let length = self.file_vec.len() / std::mem::size_of::<T>();
let ts =
unsafe { std::slice::from_raw_parts(self.file_vec.as_ptr() as *const T, length) };
if self.index - self.buffer.len() < ts.len() {
if self.index - self.buffer.len() < self.file_buffer.len() {
self.index += 1;
Some(&ts[self.index - self.buffer.len() - 1])
Some(&self.file_buffer[self.index - self.buffer.len() - 1])
} else {
None
}
Expand Down Expand Up @@ -256,8 +283,8 @@ impl<T: Hash + Eq + 'static> Collector<T> {
Ok(())
}

pub fn try_iter(&self) -> std::io::Result<impl Iterator<Item = &Entry<T>>> {
Ok(self.map.iter().chain(self.temp_array.try_iter()?))
pub fn try_iter<'lt>(&'lt self, file_buffer_store: &'lt mut Option<Box<[ManuallyDrop<Entry<T>>]>>) -> std::io::Result<impl Iterator<Item = &'lt Entry<T>>> {
Ok(self.map.iter().chain(self.temp_array.try_iter(file_buffer_store)?))
}
}

Expand Down Expand Up @@ -343,7 +370,8 @@ mod tests {
}
}

collector.try_iter().unwrap().for_each(|entry| {
let mut file_buffer_store = None;
collector.try_iter(&mut file_buffer_store).unwrap().for_each(|entry| {
test_utils::add_map(&mut real_map, entry);
});

Expand Down
6 changes: 4 additions & 2 deletions src/report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ impl<'a> ReportBuilder<'a> {
Err(Error::CreatingError)
}
Ok(profiler) => {
profiler.data.try_iter()?.for_each(|entry| {
let mut file_buffer_store = None;
profiler.data.try_iter(&mut file_buffer_store)?.for_each(|entry| {
let count = entry.count;
if count > 0 {
let key = &entry.item;
Expand Down Expand Up @@ -107,7 +108,8 @@ impl<'a> ReportBuilder<'a> {
Err(Error::CreatingError)
}
Ok(profiler) => {
profiler.data.try_iter()?.for_each(|entry| {
let mut file_buffer_store = None;
profiler.data.try_iter(&mut file_buffer_store)?.for_each(|entry| {
let count = entry.count;
if count > 0 {
let mut key = Frames::from(entry.item.clone());
Expand Down

0 comments on commit 55fa419

Please sign in to comment.