Skip to content

Commit

Permalink
Add Src: FromBytes in try_transmute_mut!, Self: IntoBytes to `T…
Browse files Browse the repository at this point in the history
…ryFromBytes::try_mut*` (#2343)

* Enforce `Src: FromBytes` in `try_transmute_mut!` (#2229)

Ensures that the source reference remains valid after the
transmuted (and possibly mutated)  destination is dropped.

Makes progress on #2226

gherrit-pr-id: I425e7d5103cb3b2a9e7107bf9df0743dca2e08cb

* Add `Self: IntoBytes` bound to `TryFromBytes::try_mut*`

Consider that `MaybeUninit<u8>` is `TryFromBytes`. If a `&mut [u8]`
is cast into a `&mut MaybeUninit<u8>`, then uninit bytes are written,
the shadowed `&mut [u8]`'s referent will no longer be valid.

Makes progress towards #2226 and #1866.

gherrit-pr-id: Ib233c4d0643e0690c53a37a08d9845e5fe43249c

---------

Co-authored-by: Jack Wrenn <[email protected]>
Co-authored-by: Jack Wrenn <[email protected]>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent 17e7e4d commit c43bbed
Show file tree
Hide file tree
Showing 17 changed files with 806 additions and 116 deletions.
50 changes: 46 additions & 4 deletions src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,24 @@ mod tests {
}
}

pub(super) trait TestTryFromMut<T: ?Sized> {
#[allow(clippy::needless_lifetimes)]
fn test_try_from_mut<'bytes>(
&self,
bytes: &'bytes mut [u8],
) -> Option<Option<&'bytes mut T>>;
}

impl<T: TryFromBytes + IntoBytes + KnownLayout + ?Sized> TestTryFromMut<T> for AutorefWrapper<T> {
#[allow(clippy::needless_lifetimes)]
fn test_try_from_mut<'bytes>(
&self,
bytes: &'bytes mut [u8],
) -> Option<Option<&'bytes mut T>> {
Some(T::try_mut_from_bytes(bytes).ok())
}
}

pub(super) trait TestTryReadFrom<T> {
fn test_try_read_from(&self, bytes: &[u8]) -> Option<Option<T>>;
}
Expand Down Expand Up @@ -1255,6 +1273,25 @@ mod tests {
None
}

#[allow(clippy::needless_lifetimes)]
fn test_try_from_mut<'bytes>(&mut self, _bytes: &'bytes mut [u8]) -> Option<Option<&'bytes mut $ty>> {
assert_on_allowlist!(
test_try_from_mut($ty):
Option<Box<UnsafeCell<NotZerocopy>>>,
Option<&'static UnsafeCell<NotZerocopy>>,
Option<&'static mut UnsafeCell<NotZerocopy>>,
Option<NonNull<UnsafeCell<NotZerocopy>>>,
Option<fn()>,
Option<FnManyArgs>,
Option<extern "C" fn()>,
Option<ECFnManyArgs>,
*const NotZerocopy,
*mut NotZerocopy
);

None
}

fn test_try_read_from(&mut self, _bytes: &[u8]) -> Option<Option<&$ty>> {
assert_on_allowlist!(
test_try_read_from($ty):
Expand Down Expand Up @@ -1363,8 +1400,10 @@ mod tests {
let bytes_mut = &mut vec.as_mut_slice()[offset..offset+size];
bytes_mut.copy_from_slice(bytes);

let res = <$ty as TryFromBytes>::try_mut_from_bytes(bytes_mut);
assert!(res.is_ok(), "{}::try_mut_from_bytes({:?}): got `Err`, expected `Ok`", stringify!($ty), val);
let res = ww.test_try_from_mut(bytes_mut);
if let Some(res) = res {
assert!(res.is_some(), "{}::try_mut_from_bytes({:?}): got `None`, expected `Some`", stringify!($ty), val);
}
}

let res = bytes.and_then(|bytes| ww.test_try_read_from(bytes));
Expand All @@ -1384,8 +1423,11 @@ mod tests {
assert!(res.is_none(), "{}::try_ref_from_bytes({:?}): got Some, expected None", stringify!($ty), c);
}

let res = <$ty as TryFromBytes>::try_mut_from_bytes(c);
assert!(res.is_err(), "{}::try_mut_from_bytes({:?}): got Ok, expected Err", stringify!($ty), c);
let res = w.test_try_from_mut(c);
if let Some(res) = res {
assert!(res.is_none(), "{}::try_mut_from_bytes({:?}): got Some, expected None", stringify!($ty), c);
}


let res = w.test_try_read_from(c);
if let Some(res) = res {
Expand Down
86 changes: 43 additions & 43 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1713,8 +1713,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: [u8; 2],
/// trailing_dst: [()],
Expand All @@ -1731,17 +1731,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -1769,7 +1769,7 @@ pub unsafe trait TryFromBytes {
#[inline]
fn try_mut_from_bytes(bytes: &mut [u8]) -> Result<&mut Self, TryCastError<&mut [u8], Self>>
where
Self: KnownLayout,
Self: KnownLayout + IntoBytes,
{
static_assert_dst_is_not_zst!(Self);
match Ptr::from_mut(bytes).try_cast_into_no_leftover::<Self, BecauseExclusive>(None) {
Expand Down Expand Up @@ -1821,8 +1821,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: [u8; 2],
/// trailing_dst: [()],
Expand All @@ -1839,17 +1839,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -1882,7 +1882,7 @@ pub unsafe trait TryFromBytes {
source: &mut [u8],
) -> Result<(&mut Self, &mut [u8]), TryCastError<&mut [u8], Self>>
where
Self: KnownLayout,
Self: KnownLayout + IntoBytes,
{
static_assert_dst_is_not_zst!(Self);
try_mut_from_prefix_suffix(source, CastType::Prefix, None)
Expand Down Expand Up @@ -1916,8 +1916,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: u16,
/// trailing_dst: [()],
Expand All @@ -1934,17 +1934,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -1977,7 +1977,7 @@ pub unsafe trait TryFromBytes {
source: &mut [u8],
) -> Result<(&mut [u8], &mut Self), TryCastError<&mut [u8], Self>>
where
Self: KnownLayout,
Self: KnownLayout + IntoBytes,
{
static_assert_dst_is_not_zst!(Self);
try_mut_from_prefix_suffix(source, CastType::Suffix, None).map(swap)
Expand Down Expand Up @@ -2286,17 +2286,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -2330,8 +2330,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: NonZeroU16,
/// trailing_dst: [()],
Expand All @@ -2351,7 +2351,7 @@ pub unsafe trait TryFromBytes {
count: usize,
) -> Result<&mut Self, TryCastError<&mut [u8], Self>>
where
Self: KnownLayout<PointerMetadata = usize>,
Self: KnownLayout<PointerMetadata = usize> + IntoBytes,
{
match Ptr::from_mut(source).try_cast_into_no_leftover::<Self, BecauseExclusive>(Some(count))
{
Expand Down Expand Up @@ -2397,17 +2397,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -2443,8 +2443,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: NonZeroU16,
/// trailing_dst: [()],
Expand All @@ -2464,7 +2464,7 @@ pub unsafe trait TryFromBytes {
count: usize,
) -> Result<(&mut Self, &mut [u8]), TryCastError<&mut [u8], Self>>
where
Self: KnownLayout<PointerMetadata = usize>,
Self: KnownLayout<PointerMetadata = usize> + IntoBytes,
{
try_mut_from_prefix_suffix(source, CastType::Prefix, Some(count))
}
Expand Down Expand Up @@ -2492,17 +2492,17 @@ pub unsafe trait TryFromBytes {
/// # use zerocopy_derive::*;
///
/// // The only valid value of this type is the byte `0xC0`
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(u8)]
/// enum C0 { xC0 = 0xC0 }
///
/// // The only valid value of this type is the bytes `0xC0C0`.
/// #[derive(TryFromBytes, KnownLayout)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C)]
/// struct C0C0(C0, C0);
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct Packet {
/// magic_number: C0C0,
/// mug_size: u8,
Expand Down Expand Up @@ -2538,8 +2538,8 @@ pub unsafe trait TryFromBytes {
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout)]
/// #[repr(C)]
/// #[derive(TryFromBytes, IntoBytes, KnownLayout)]
/// #[repr(C, packed)]
/// struct ZSTy {
/// leading_sized: NonZeroU16,
/// trailing_dst: [()],
Expand All @@ -2559,7 +2559,7 @@ pub unsafe trait TryFromBytes {
count: usize,
) -> Result<(&mut [u8], &mut Self), TryCastError<&mut [u8], Self>>
where
Self: KnownLayout<PointerMetadata = usize>,
Self: KnownLayout<PointerMetadata = usize> + IntoBytes,
{
try_mut_from_prefix_suffix(source, CastType::Suffix, Some(count)).map(swap)
}
Expand Down Expand Up @@ -2771,7 +2771,7 @@ fn try_ref_from_prefix_suffix<T: TryFromBytes + KnownLayout + Immutable + ?Sized
}

#[inline(always)]
fn try_mut_from_prefix_suffix<T: TryFromBytes + KnownLayout + ?Sized>(
fn try_mut_from_prefix_suffix<T: IntoBytes + TryFromBytes + KnownLayout + ?Sized>(
candidate: &mut [u8],
cast_type: CastType,
meta: Option<T::PointerMetadata>,
Expand Down
22 changes: 11 additions & 11 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ macro_rules! transmute_ref {
/// const fn transmute_mut<'src, 'dst, Src, Dst>(src: &'src mut Src) -> &'dst mut Dst
/// where
/// 'src: 'dst,
/// Src: FromBytes + IntoBytes + Immutable,
/// Dst: FromBytes + IntoBytes + Immutable,
/// Src: FromBytes + IntoBytes,
/// Dst: FromBytes + IntoBytes,
/// size_of::<Src>() == size_of::<Dst>(),
/// align_of::<Src>() >= align_of::<Dst>(),
/// {
Expand Down Expand Up @@ -325,9 +325,9 @@ macro_rules! transmute_mut {
#[allow(unused, clippy::diverging_sub_expression)]
if false {
// This branch, though never taken, ensures that the type of `e` is
// `&mut T` where `T: 't + Sized + FromBytes + IntoBytes + Immutable`
// and that the type of this macro expression is `&mut U` where `U:
// 'u + Sized + FromBytes + IntoBytes + Immutable`.
// `&mut T` where `T: 't + Sized + FromBytes + IntoBytes` and that
// the type of this macro expression is `&mut U` where `U: 'u +
// Sized + FromBytes + IntoBytes`.

// We use immutable references here rather than mutable so that, if
// this macro is used in a const context (in which, as of this
Expand Down Expand Up @@ -577,8 +577,8 @@ macro_rules! try_transmute_ref {
/// ```ignore
/// fn try_transmute_mut<Src, Dst>(src: &mut Src) -> Result<&mut Dst, ValidityError<&mut Src, Dst>>
/// where
/// Src: IntoBytes,
/// Dst: TryFromBytes,
/// Src: FromBytes + IntoBytes,
/// Dst: TryFromBytes + IntoBytes,
/// size_of::<Src>() == size_of::<Dst>(),
/// align_of::<Src>() >= align_of::<Dst>(),
/// {
Expand Down Expand Up @@ -888,9 +888,9 @@ mod tests {
#[test]
fn test_try_transmute_mut() {
// Test that memory is transmuted with `try_transmute_mut` as expected.
let array_of_bools = &mut [false, true, false, true, false, true, false, true];
let array_of_u8s = &mut [0u8, 1, 0, 1, 0, 1, 0, 1];
let array_of_arrays = &mut [[0u8, 1], [0, 1], [0, 1], [0, 1]];
let x: Result<&mut [[u8; 2]; 4], _> = try_transmute_mut!(array_of_bools);
let x: Result<&mut [[u8; 2]; 4], _> = try_transmute_mut!(array_of_u8s);
assert_eq!(x, Ok(array_of_arrays));

let array_of_bools = &mut [false, true, false, true, false, true, false, true];
Expand All @@ -903,8 +903,8 @@ mod tests {
let array_of_bools = &mut [false, true, false, true, false, true, false, true];
let array_of_arrays = &mut [[0u8, 1], [0, 1], [0, 1], [0, 1]];
{
let x: Result<&mut [[u8; 2]; 4], _> = try_transmute_mut!(array_of_bools);
assert_eq!(x, Ok(array_of_arrays));
let x: Result<&mut [bool; 8], _> = try_transmute_mut!(array_of_arrays);
assert_eq!(x, Ok(array_of_bools));
}

// Test that `try_transmute_mut!` supports decreasing alignment.
Expand Down
Loading

0 comments on commit c43bbed

Please sign in to comment.