Skip to content

Commit

Permalink
add more checking for copy_non_overlapping() calls
Browse files Browse the repository at this point in the history
  • Loading branch information
dwrensha committed Oct 15, 2024
1 parent 04d2dd6 commit 7b094f0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
29 changes: 17 additions & 12 deletions capnp/src/private/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ mod wire_helpers {
arena.amplified_read(virtual_amount)
}

#[inline]
pub unsafe fn copy_nonoverlapping_check_zero<T>(src: *const T, dst: *mut T, count: usize) {
if count > 0 {
ptr::copy_nonoverlapping(src, dst, count);
}
}

#[inline]
pub unsafe fn allocate(
arena: &mut dyn BuilderArena,
Expand Down Expand Up @@ -840,7 +847,7 @@ mod wire_helpers {
data_size: isize,
pointer_count: isize,
) {
ptr::copy_nonoverlapping(src, dst, data_size as usize * BYTES_PER_WORD);
copy_nonoverlapping_check_zero(src, dst, data_size as usize * BYTES_PER_WORD);

let src_refs: *const WirePointer = (src as *const WirePointer).offset(data_size);
let dst_refs: *mut WirePointer = (dst as *mut WirePointer).offset(data_size);
Expand Down Expand Up @@ -909,7 +916,7 @@ mod wire_helpers {
let src_ptr = WirePointer::target(src);
let (dst_ptr, dst, segment_id) =
allocate(arena, dst, segment_id, word_count, WirePointerKind::List);
ptr::copy_nonoverlapping(
copy_nonoverlapping_check_zero(
src_ptr,
dst_ptr,
word_count as usize * BYTES_PER_WORD,
Expand Down Expand Up @@ -1202,7 +1209,7 @@ mod wire_helpers {

// Copy data section.
// Note: copy_nonoverlapping's third argument is an element count, not a byte count.
ptr::copy_nonoverlapping(old_ptr, ptr, old_data_size as usize * BYTES_PER_WORD);
copy_nonoverlapping_check_zero(old_ptr, ptr, old_data_size as usize * BYTES_PER_WORD);

//# Copy pointer section.
let new_pointer_section: *mut WirePointer =
Expand Down Expand Up @@ -1564,7 +1571,7 @@ mod wire_helpers {
let mut dst = new_ptr as *mut WirePointer;
for _ in 0..element_count {
// Copy data section.
ptr::copy_nonoverlapping(src, dst, old_data_size as usize);
copy_nonoverlapping_check_zero(src, dst, old_data_size as usize);

// Copy pointer section
let new_pointer_section = dst.offset(new_data_size as isize);
Expand Down Expand Up @@ -1674,7 +1681,7 @@ mod wire_helpers {
let mut src: *mut u8 = old_ptr;
let old_byte_step = old_data_size / BITS_PER_BYTE as u32;
for _ in 0..element_count {
ptr::copy_nonoverlapping(src, dst, old_byte_step as usize);
copy_nonoverlapping_check_zero(src, dst, old_byte_step as usize);
src = src.offset(old_byte_step as isize);
dst = dst.offset(new_step as isize * BYTES_PER_WORD as isize);
}
Expand Down Expand Up @@ -1831,7 +1838,7 @@ mod wire_helpers {
value: &[u8],
) -> SegmentAnd<data::Builder<'a>> {
let allocation = init_data_pointer(arena, reff, segment_id, value.len() as u32);
ptr::copy_nonoverlapping(value.as_ptr(), allocation.value.as_mut_ptr(), value.len());
copy_nonoverlapping_check_zero(value.as_ptr(), allocation.value.as_mut_ptr(), value.len());
allocation
}

Expand Down Expand Up @@ -1937,7 +1944,7 @@ mod wire_helpers {
*ptr = u8::from(value.get_bool_field(0))
}
} else {
ptr::copy_nonoverlapping::<u8>(value.data, ptr, data_size as usize);
copy_nonoverlapping_check_zero::<u8>(value.data, ptr, data_size as usize);
}

let pointer_section: *mut WirePointer =
Expand Down Expand Up @@ -2032,9 +2039,7 @@ mod wire_helpers {
// in the canonicalize=true case.
let whole_byte_size =
u64::from(value.element_count) * u64::from(value.step) / BITS_PER_BYTE as u64;
if whole_byte_size > 0 {
ptr::copy_nonoverlapping(value.ptr, ptr, whole_byte_size as usize);
}
copy_nonoverlapping_check_zero(value.ptr, ptr, whole_byte_size as usize);

let leftover_bits =
u64::from(value.element_count) * u64::from(value.step) % BITS_PER_BYTE as u64;
Expand Down Expand Up @@ -2114,7 +2119,7 @@ mod wire_helpers {

let mut src: *const u8 = value.ptr;
for _ in 0..value.element_count {
ptr::copy_nonoverlapping(src, dst, data_size as usize * BYTES_PER_WORD);
copy_nonoverlapping_check_zero(src, dst, data_size as usize * BYTES_PER_WORD);
dst = dst.offset(data_size as isize * BYTES_PER_WORD as isize);
src = src.offset(decl_data_size as isize * BYTES_PER_WORD as isize);

Expand Down Expand Up @@ -3738,7 +3743,7 @@ impl<'a> StructBuilder<'a> {
if shared_data_size == 1 {
self.set_bool_field(0, other.get_bool_field(0));
} else {
ptr::copy_nonoverlapping(
wire_helpers::copy_nonoverlapping_check_zero(
other.data,
self.data,
(shared_data_size / BITS_PER_BYTE as u32) as usize,
Expand Down
10 changes: 7 additions & 3 deletions capnpc/test/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1527,16 +1527,20 @@ mod tests {

/// https://github.com/capnproto/capnproto-rust/issues/525
#[test]
fn set_data_list_null() {
fn copy_nonoverlapping_null() {
use crate::test_capnp::test_all_types;

let mut message = message::Builder::new_default();
let root: test_all_types::Builder<'_> = message.init_root();
let mut root: test_all_types::Builder<'_> = message.init_root();

let mut message2 = message::Builder::new_default();
let mut root2: test_all_types::Builder<'_> = message2.init_root();
root2
.set_data_list(root.into_reader().get_data_list().unwrap())
.set_data_list(root.reborrow().into_reader().get_data_list().unwrap())
.unwrap();

root2
.set_struct_field(root.into_reader().get_struct_field().unwrap())
.unwrap();
}

Expand Down

0 comments on commit 7b094f0

Please sign in to comment.