Use memcpy to decode Vec<int>.
caibear committed Mar 22, 2024
1 parent 0bb9642 commit 5655c92
Showing 5 changed files with 66 additions and 49 deletions.
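
The commit's core idea: when a `Vec`'s element type is a plain integer, the decoder's borrowed input is already the exact bytes of the output, so the whole vector can be decoded with one bulk copy instead of a per-element loop. A self-contained sketch of that fast path (illustrative only; the real implementation below uses the crate-internal `SliceImpl` and `Unaligned` types, and this sketch ignores any endian conversion the crate performs):

```rust
/// Decode `len` u32s from raw input bytes with a single memcpy.
fn decode_u32s_bulk(input: &[u8], len: usize) -> Option<Vec<u32>> {
    let byte_len = len.checked_mul(4)?;
    if input.len() < byte_len {
        return None;
    }
    let mut v: Vec<u32> = Vec::with_capacity(len);
    // Safety: `v` has capacity for `len` u32s, `input` holds at least
    // `byte_len` bytes, and byte pointers have no alignment requirement.
    unsafe {
        std::ptr::copy_nonoverlapping(input.as_ptr(), v.as_mut_ptr() as *mut u8, byte_len);
        v.set_len(len);
    }
    Some(v)
}
```
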
13 changes: 5 additions & 8 deletions src/bool.rs
@@ -1,5 +1,5 @@
 use crate::coder::{Buffer, Decoder, Encoder, Result, View};
-use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl};
+use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, SliceImpl, Unaligned, VecImpl};
 use crate::pack::{pack_bools, unpack_bools};
 use std::num::NonZeroUsize;

@@ -41,13 +41,10 @@ impl<'a> View<'a> for BoolDecoder<'a> {

 impl<'a> Decoder<'a, bool> for BoolDecoder<'a> {
     #[inline(always)]
-    fn as_primitive_ptr(&self) -> Option<*const u8> {
-        Some(self.0.ref_slice().as_ptr() as *const u8)
-    }
-
-    #[inline(always)]
-    unsafe fn as_primitive_advance(&mut self, n: usize) {
-        self.0.mut_slice().advance(n);
+    fn as_primitive(&mut self) -> Option<&mut SliceImpl<Unaligned<bool>>> {
+        // Safety: `Unaligned<bool>` is equivalent to bool since it's a `#[repr(C, packed)]` wrapper
+        // around bool and both have size/align of 1.
+        unsafe { Some(std::mem::transmute(self.0.mut_slice())) }
     }

     #[inline(always)]
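
The safety comment above rests on `Unaligned<bool>` having exactly the layout of `bool`. A quick standalone check of that claim, mirroring the wrapper added to src/fast.rs later in this commit:

```rust
// Local mirror of the `Unaligned` wrapper defined in src/fast.rs.
#[derive(Copy, Clone)]
#[repr(C, packed)]
struct Unaligned<T>(T);

fn main() {
    use std::mem::{align_of, size_of};
    // `packed` caps field alignment at 1; `bool` is already size 1 and
    // align 1, so wrapping changes nothing about the layout.
    assert_eq!(size_of::<Unaligned<bool>>(), size_of::<bool>());
    assert_eq!(align_of::<Unaligned<bool>>(), align_of::<bool>());
}
```
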
17 changes: 5 additions & 12 deletions src/coder.rs
@@ -1,4 +1,4 @@
-use crate::fast::VecImpl;
+use crate::fast::{SliceImpl, Unaligned, VecImpl};
 use std::mem::MaybeUninit;
 use std::num::NonZeroUsize;

@@ -25,7 +25,7 @@ pub trait Buffer {
 pub const MAX_VECTORED_CHUNK: usize = 64;

 pub trait Encoder<T: ?Sized>: Buffer + Default {
-    /// Returns a `VecImpl<T>` if `T` is a type that can be encoded by copying.
+    /// Returns a `&mut VecImpl<T>` if `T` is a type that can be encoded by copying.
     #[inline(always)]
     fn as_primitive(&mut self) -> Option<&mut VecImpl<T>>
     where
@@ -67,20 +67,13 @@ pub trait View<'a> {
 /// One of [`Decoder::decode`] and [`Decoder::decode_in_place`] must be implemented or calling
 /// either one will result in infinite recursion and a stack overflow.
 pub trait Decoder<'a, T>: View<'a> + Default {
-    /// Returns a pointer to the current element if it can be decoded by copying.
+    /// Returns a `&mut SliceImpl<Unaligned<T>>` if `T` is a type that can be decoded by copying.
+    /// Uses `Unaligned<T>` so `IntDecoder` can borrow from input `[u8]`.
     #[inline(always)]
-    fn as_primitive_ptr(&self) -> Option<*const u8> {
+    fn as_primitive(&mut self) -> Option<&mut SliceImpl<Unaligned<T>>> {
         None
     }

-    /// Assuming [`Self::as_primitive_ptr`] returns `Some(ptr)`, this advances `ptr` by `n`.
-    /// # Safety
-    /// Can only decode `self.populate(_, length)` items.
-    unsafe fn as_primitive_advance(&mut self, n: usize) {
-        let _ = n;
-        unreachable!();
-    }
-
     /// Decodes a single value. Can't error since `View::populate` has already validated the input.
     /// Prefer decode for primitives (since it's simpler) and decode_in_place for array/struct/tuple.
     /// # Safety
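
For context, callers dispatch on this method: take the bulk path when a typed slice is available, otherwise decode element by element (src/derive/vec.rs below does exactly this). A toy model of the dispatch, simplified to safe code with the slice handed out directly instead of through `SliceImpl` and `advance`:

```rust
// Toy model of the Decoder contract's two paths.
trait Decoder {
    type Item: Copy;
    /// Some(exactly `n` items) if items can be decoded by copying.
    fn as_primitive(&mut self, n: usize) -> Option<&[Self::Item]> {
        let _ = n;
        None
    }
    fn decode(&mut self) -> Self::Item;
}

fn decode_n<D: Decoder>(d: &mut D, out: &mut Vec<D::Item>, n: usize) {
    if let Some(items) = d.as_primitive(n) {
        out.extend_from_slice(items); // one memcpy for Copy items
    } else {
        for _ in 0..n {
            out.push(d.decode());
        }
    }
}
```
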
33 changes: 12 additions & 21 deletions src/derive/vec.rs
@@ -1,5 +1,6 @@
 use crate::coder::{Buffer, Decoder, Encoder, Result, View, MAX_VECTORED_CHUNK};
 use crate::derive::{Decode, Encode};
+use crate::fast::Unaligned;
 use crate::length::{LengthDecoder, LengthEncoder};
 use std::collections::{BTreeSet, BinaryHeap, HashSet, LinkedList, VecDeque};
 use std::hash::{BuildHasher, Hash};
@@ -40,7 +41,8 @@ impl<T: Encode> Buffer for VecEncoder<T> {

 /// Copies `N` or `n` bytes from `src` to `dst` depending on if `src` lies within a memory page.
 /// https://stackoverflow.com/questions/37800739/is-it-safe-to-read-past-the-end-of-a-buffer-within-the-same-page-on-x86-and-x64
-/// Safety: Same as [`copy_nonoverlapping_unaligned`] but with the additional requirements that
+/// # Safety
+/// Same as [`std::ptr::copy_nonoverlapping`] but with the additional requirements that
 /// `n != 0 && n <= N` and `dst` has room for a `[T; N]`.
 /// Is a macro instead of an `#[inline(always)] fn` because it optimizes better.
 macro_rules! unsafe_wild_copy {
@@ -62,31 +64,18 @@ macro_rules! unsafe_wild_copy {
         ));

         if within_page {
-            std::ptr::write_unaligned($dst as *mut std::mem::MaybeUninit<[$T; $N]>,
-                std::ptr::read_unaligned($src as *const std::mem::MaybeUninit<[$T; $N]>)
-            );
+            *($dst as *mut std::mem::MaybeUninit<[$T; $N]>) = std::ptr::read($src as *const std::mem::MaybeUninit<[$T; $N]>);
         } else {
             #[cold]
             unsafe fn cold<T>(src: *const T, dst: *mut T, n: usize) {
-                crate::derive::vec::copy_nonoverlapping_unaligned(src, dst, n);
+                src.copy_to_nonoverlapping(dst, n);
             }
             cold($src, $dst, $n);
         }
     }
 }
 pub(crate) use unsafe_wild_copy;

-/// Equivalent to `std::ptr::copy_nonoverlapping` but neither `src` nor `dst` has to be aligned.
-/// Safety: Same as [`std::ptr::copy_nonoverlapping`], but without any alignment requirements.
-#[inline(always)]
-pub unsafe fn copy_nonoverlapping_unaligned<T>(src: *const T, dst: *mut T, n: usize) {
-    std::ptr::copy_nonoverlapping(
-        src as *const u8,
-        dst as *mut u8,
-        n * std::mem::size_of::<T>(),
-    );
-}
-
 impl<T: Encode> VecEncoder<T> {
     /// Copy fixed size slices. Much faster than memcpy.
     #[inline(never)]
@@ -143,7 +132,7 @@ impl<T: Encode> VecEncoder<T> {
             let n = s.len();
             primitives.reserve(n);
             let ptr = primitives.end_ptr();
-            copy_nonoverlapping_unaligned(s.as_ptr(), ptr, n);
+            s.as_ptr().copy_to_nonoverlapping(ptr, n);
             primitives.set_end_ptr(ptr.add(n));
         });
     }
@@ -159,7 +148,7 @@ impl<T: Encode> Encoder<[T]> for VecEncoder<T> {
             primitive.reserve(n);
             unsafe {
                 let ptr = primitive.end_ptr();
-                copy_nonoverlapping_unaligned(v.as_ptr(), ptr, n);
+                v.as_ptr().copy_to_nonoverlapping(ptr, n);
                 primitive.set_end_ptr(ptr.add(n));
             }
         } else if let Some(n) = NonZeroUsize::new(n) {
@@ -300,10 +289,12 @@ impl<'a, T: Decode<'a>> Decoder<'a, Vec<T>> for VecDecoder<'a, T> {
         }

         let v = out.write(Vec::with_capacity(length));
-        if let Some(primitive) = self.elements.as_primitive_ptr() {
+        if let Some(primitive) = self.elements.as_primitive() {
             unsafe {
-                copy_nonoverlapping_unaligned(primitive as *const T, v.as_mut_ptr(), length);
-                self.elements.as_primitive_advance(length);
+                primitive
+                    .as_ptr()
+                    .copy_to_nonoverlapping(v.as_mut_ptr() as *mut Unaligned<T>, length);
+                primitive.advance(length);
             }
         } else {
             let spare = v.spare_capacity_mut();
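
The `unsafe_wild_copy` hunk above keeps the page-boundary trick intact: when reading a whole `[T; N]` cannot cross into the next page, over-read with one fixed-size copy; otherwise fall back to an exact copy. A standalone sketch of the guard (assuming 4096-byte pages, as the linked Stack Overflow answer does; the real macro also gates this on x86_64 and non-Miri builds):

```rust
use std::mem::MaybeUninit;

/// Copy `n` elements, possibly over-reading up to `N` from `src`.
/// Safety: `n != 0 && n <= N`, regions don't overlap, and `dst` has room for `[T; N]`.
unsafe fn wild_copy<T, const N: usize>(src: *const T, dst: *mut T, n: usize) {
    const PAGE: usize = 4096; // assumed page size
    debug_assert!(n != 0 && n <= N);
    // The over-read stays on `src`'s page iff the whole [T; N] ends
    // before the next page boundary.
    let within_page = (src as usize & (PAGE - 1)) + std::mem::size_of::<[T; N]>() <= PAGE;
    if within_page {
        (dst as *mut MaybeUninit<[T; N]>)
            .write_unaligned((src as *const MaybeUninit<[T; N]>).read_unaligned());
    } else {
        src.copy_to_nonoverlapping(dst, n);
    }
}
```
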
22 changes: 22 additions & 0 deletions src/fast.rs
@@ -277,6 +277,20 @@ impl<'a, T> FastSlice<'a, T> {
     pub fn as_ptr(&self) -> *const T {
         self.ptr
     }
+
+    /// Casts `&mut FastSlice<T>` to `&mut FastSlice<B>`.
+    #[inline(always)]
+    pub fn cast<B>(&mut self) -> &mut FastSlice<'a, B>
+    where
+        T: bytemuck::Pod,
+        B: bytemuck::Pod,
+    {
+        use std::mem::*;
+        assert_eq!(size_of::<T>(), size_of::<B>());
+        assert_eq!(align_of::<T>(), align_of::<B>());
+        // Safety: size/align are equal and both are bytemuck::Pod.
+        unsafe { transmute(self) }
+    }
 }

pub trait NextUnchecked<'a, T: Copy> {
@@ -459,6 +473,14 @@ impl<'a, T> std::ops::DerefMut for SetOwned<'a, '_, T> {
     }
 }

+#[derive(Copy, Clone)]
+#[repr(C, packed)]
+pub struct Unaligned<T>(T);
+
+// Could derive with bytemuck/derive.
+unsafe impl<T: bytemuck::Zeroable> bytemuck::Zeroable for Unaligned<T> {}
+unsafe impl<T: bytemuck::Pod> bytemuck::Pod for Unaligned<T> {}
+
 #[cfg(test)]
 mod tests {
     use super::*;
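
`FastSlice::cast` plus `Unaligned` is what lets `IntDecoder` (below) reinterpret its borrowed bytes as any same-size, same-align `Pod` type. The same reinterpretation in miniature, using bytemuck's safe slice casts on ordinary slices:

```rust
fn main() {
    let ints: &[u32] = &[1, 2, 3];
    // u32 -> i32: same size, same alignment, both Pod, so the cast is free.
    let signed: &[i32] = bytemuck::cast_slice(ints);
    assert_eq!(signed, &[1, 2, 3][..]);
    // Like the assert_eq! guards in FastSlice::cast, bytemuck panics at
    // runtime if alignment or total size wouldn't line up.
}
```
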
30 changes: 22 additions & 8 deletions src/int.rs
@@ -1,6 +1,6 @@
 use crate::coder::{Buffer, Decoder, Encoder, Result, View};
 use crate::error::err;
-use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl};
+use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, SliceImpl, Unaligned, VecImpl};
 use crate::pack_ints::{pack_ints, unpack_ints, Int};
 use bytemuck::{CheckedBitPattern, NoUninit, Pod};
 use std::marker::PhantomData;
@@ -59,6 +59,11 @@ impl<'a, T: Int> View<'a> for IntDecoder<'a, T> {

 // Makes IntDecoder<u32> able to decode i32/f32 (but not char since it can fail).
 impl<'a, T: Int, P: Pod> Decoder<'a, P> for IntDecoder<'a, T> {
+    #[inline(always)]
+    fn as_primitive(&mut self) -> Option<&mut SliceImpl<Unaligned<P>>> {
+        Some(self.0.mut_slice().cast())
+    }
+
     #[inline(always)]
     fn decode(&mut self) -> P {
         let v = unsafe { self.0.mut_slice().next_unchecked() };
@@ -81,7 +86,6 @@
     <C as CheckedBitPattern>::Bits: Pod,
 {
     fn populate(&mut self, input: &mut &'a [u8], length: usize) -> Result<()> {
-        assert_eq!(std::mem::size_of::<C>(), std::mem::size_of::<I>());
         self.0.populate(input, length)?;

         let mut decoder = self.0.borrowed_clone();
@@ -97,13 +101,23 @@
     <C as CheckedBitPattern>::Bits: Pod,
 {
     #[inline(always)]
-    fn decode(&mut self) -> C {
-        let i: I = self.0.decode();
+    fn as_primitive(&mut self) -> Option<&mut SliceImpl<Unaligned<C>>> {
+        self.0
+            .as_primitive()
+            .map(|p: &mut SliceImpl<'_, Unaligned<I>>| {
+                let p = p.cast::<Unaligned<C::Bits>>();
+                // Safety: `Unaligned<C::Bits>` and `Unaligned<C>` have the same layout and populate
+                // ensured C's bit pattern is valid.
+                unsafe { std::mem::transmute(p) }
+            })
+    }

-        // Safety: populate ensures:
-        // - C and I are of the same size.
-        // - The checked bit pattern of C is valid.
-        unsafe { std::mem::transmute_copy(&i) }
+    #[inline(always)]
+    fn decode(&mut self) -> C {
+        let v: I = self.0.decode();
+        let v: C::Bits = bytemuck::must_cast(v);
+        // Safety: C::Bits and C have the same layout and populate ensured C's bit pattern is valid.
+        unsafe { std::mem::transmute_copy(&v) }
     }
 }
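
The rewritten `decode` makes the validate-then-transmute contract explicit: `populate` has already checked every bit pattern, `bytemuck::must_cast` proves `I` and `C::Bits` have the same size at compile time (which is why the runtime size `assert_eq!` could be deleted from `populate`), and the final transmute only converts known-valid bits. The same two steps for a concrete `CheckedBitPattern` type, `char` (whose `Bits` is `u32`):

```rust
use bytemuck::CheckedBitPattern;

/// Validate once (populate), then convert without rechecking (decode).
fn decode_char(raw: u32) -> Option<char> {
    if !<char as CheckedBitPattern>::is_valid_bit_pattern(&raw) {
        return None; // populate() would report an error here
    }
    // Safety: the bit pattern was just validated.
    Some(unsafe { char::from_u32_unchecked(raw) })
}

fn main() {
    assert_eq!(decode_char(0x61), Some('a'));
    assert_eq!(decode_char(0xD800), None); // UTF-16 surrogate: not a char
}
```
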
