From 7b094f0deae3bb1c0efe7c51b8f094113992c133 Mon Sep 17 00:00:00 2001 From: David Renshaw Date: Tue, 15 Oct 2024 14:08:19 -0400 Subject: [PATCH] add more checking for copy_non_overlapping() calls --- capnp/src/private/layout.rs | 29 +++++++++++++++++------------ capnpc/test/test.rs | 10 +++++++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/capnp/src/private/layout.rs b/capnp/src/private/layout.rs index 455e3f9e9..dcd1285e8 100644 --- a/capnp/src/private/layout.rs +++ b/capnp/src/private/layout.rs @@ -404,6 +404,13 @@ mod wire_helpers { arena.amplified_read(virtual_amount) } + #[inline] + pub unsafe fn copy_nonoverlapping_check_zero(src: *const T, dst: *mut T, count: usize) { + if count > 0 { + ptr::copy_nonoverlapping(src, dst, count); + } + } + #[inline] pub unsafe fn allocate( arena: &mut dyn BuilderArena, @@ -840,7 +847,7 @@ mod wire_helpers { data_size: isize, pointer_count: isize, ) { - ptr::copy_nonoverlapping(src, dst, data_size as usize * BYTES_PER_WORD); + copy_nonoverlapping_check_zero(src, dst, data_size as usize * BYTES_PER_WORD); let src_refs: *const WirePointer = (src as *const WirePointer).offset(data_size); let dst_refs: *mut WirePointer = (dst as *mut WirePointer).offset(data_size); @@ -909,7 +916,7 @@ mod wire_helpers { let src_ptr = WirePointer::target(src); let (dst_ptr, dst, segment_id) = allocate(arena, dst, segment_id, word_count, WirePointerKind::List); - ptr::copy_nonoverlapping( + copy_nonoverlapping_check_zero( src_ptr, dst_ptr, word_count as usize * BYTES_PER_WORD, @@ -1202,7 +1209,7 @@ mod wire_helpers { // Copy data section. // Note: copy_nonoverlapping's third argument is an element count, not a byte count. - ptr::copy_nonoverlapping(old_ptr, ptr, old_data_size as usize * BYTES_PER_WORD); + copy_nonoverlapping_check_zero(old_ptr, ptr, old_data_size as usize * BYTES_PER_WORD); //# Copy pointer section. let new_pointer_section: *mut WirePointer = @@ -1564,7 +1571,7 @@ mod wire_helpers { let mut dst = new_ptr as *mut WirePointer; for _ in 0..element_count { // Copy data section. - ptr::copy_nonoverlapping(src, dst, old_data_size as usize); + copy_nonoverlapping_check_zero(src, dst, old_data_size as usize); // Copy pointer section let new_pointer_section = dst.offset(new_data_size as isize); @@ -1674,7 +1681,7 @@ mod wire_helpers { let mut src: *mut u8 = old_ptr; let old_byte_step = old_data_size / BITS_PER_BYTE as u32; for _ in 0..element_count { - ptr::copy_nonoverlapping(src, dst, old_byte_step as usize); + copy_nonoverlapping_check_zero(src, dst, old_byte_step as usize); src = src.offset(old_byte_step as isize); dst = dst.offset(new_step as isize * BYTES_PER_WORD as isize); } @@ -1831,7 +1838,7 @@ mod wire_helpers { value: &[u8], ) -> SegmentAnd> { let allocation = init_data_pointer(arena, reff, segment_id, value.len() as u32); - ptr::copy_nonoverlapping(value.as_ptr(), allocation.value.as_mut_ptr(), value.len()); + copy_nonoverlapping_check_zero(value.as_ptr(), allocation.value.as_mut_ptr(), value.len()); allocation } @@ -1937,7 +1944,7 @@ mod wire_helpers { *ptr = u8::from(value.get_bool_field(0)) } } else { - ptr::copy_nonoverlapping::(value.data, ptr, data_size as usize); + copy_nonoverlapping_check_zero::(value.data, ptr, data_size as usize); } let pointer_section: *mut WirePointer = @@ -2032,9 +2039,7 @@ mod wire_helpers { // in the canonicalize=true case. let whole_byte_size = u64::from(value.element_count) * u64::from(value.step) / BITS_PER_BYTE as u64; - if whole_byte_size > 0 { - ptr::copy_nonoverlapping(value.ptr, ptr, whole_byte_size as usize); - } + copy_nonoverlapping_check_zero(value.ptr, ptr, whole_byte_size as usize); let leftover_bits = u64::from(value.element_count) * u64::from(value.step) % BITS_PER_BYTE as u64; @@ -2114,7 +2119,7 @@ mod wire_helpers { let mut src: *const u8 = value.ptr; for _ in 0..value.element_count { - ptr::copy_nonoverlapping(src, dst, data_size as usize * BYTES_PER_WORD); + copy_nonoverlapping_check_zero(src, dst, data_size as usize * BYTES_PER_WORD); dst = dst.offset(data_size as isize * BYTES_PER_WORD as isize); src = src.offset(decl_data_size as isize * BYTES_PER_WORD as isize); @@ -3738,7 +3743,7 @@ impl<'a> StructBuilder<'a> { if shared_data_size == 1 { self.set_bool_field(0, other.get_bool_field(0)); } else { - ptr::copy_nonoverlapping( + wire_helpers::copy_nonoverlapping_check_zero( other.data, self.data, (shared_data_size / BITS_PER_BYTE as u32) as usize, diff --git a/capnpc/test/test.rs b/capnpc/test/test.rs index 9a38893b1..957da3c8a 100644 --- a/capnpc/test/test.rs +++ b/capnpc/test/test.rs @@ -1527,16 +1527,20 @@ mod tests { /// https://github.com/capnproto/capnproto-rust/issues/525 #[test] - fn set_data_list_null() { + fn copy_nonoverlapping_null() { use crate::test_capnp::test_all_types; let mut message = message::Builder::new_default(); - let root: test_all_types::Builder<'_> = message.init_root(); + let mut root: test_all_types::Builder<'_> = message.init_root(); let mut message2 = message::Builder::new_default(); let mut root2: test_all_types::Builder<'_> = message2.init_root(); root2 - .set_data_list(root.into_reader().get_data_list().unwrap()) + .set_data_list(root.reborrow().into_reader().get_data_list().unwrap()) + .unwrap(); + + root2 + .set_struct_field(root.into_reader().get_struct_field().unwrap()) .unwrap(); }