diff --git a/rust/upb.rs b/rust/upb.rs index 5a0c07e47156..94260e6715ae 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -14,9 +14,7 @@ use crate::{ }; use core::fmt::Debug; use std::alloc::Layout; -use std::fmt; use std::mem::{size_of, ManuallyDrop, MaybeUninit}; -use std::ops::Deref; use std::ptr::{self, NonNull}; use std::slice; use std::sync::OnceLock; @@ -60,48 +58,7 @@ impl ScratchSpace { } } -/// Serialized Protobuf wire format data. -/// -/// It's typically produced by `::serialize()`. -pub struct SerializedData { - data: NonNull, - len: usize, - - // The arena that owns `data`. - _arena: Arena, -} - -impl SerializedData { - /// Construct `SerializedData` from raw pointers and its owning arena. - /// - /// # Safety - /// - `arena` must be have allocated `data` - /// - `data` must be readable for `len` bytes and not mutate while this - /// struct exists - pub unsafe fn from_raw_parts(arena: Arena, data: NonNull, len: usize) -> Self { - SerializedData { _arena: arena, data, len } - } - - /// Gets a raw slice pointer. - pub fn as_ptr(&self) -> *const [u8] { - ptr::slice_from_raw_parts(self.data.as_ptr(), self.len) - } -} - -impl Deref for SerializedData { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - // SAFETY: `data` is valid for `len` bytes as promised by - // the caller of `SerializedData::from_raw_parts`. 
- unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len) } - } -} - -impl fmt::Debug for SerializedData { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(self.deref(), f) - } -} +pub type SerializedData = upb::OwnedData<[u8]>; impl SettableValue<[u8]> for SerializedData { fn set_on<'msg>(self, _private: Private, mut mutator: Mut<'msg, [u8]>) @@ -814,22 +771,6 @@ mod tests { use super::*; use googletest::prelude::*; - #[test] - fn test_serialized_data_roundtrip() { - let arena = Arena::new(); - let original_data = b"Hello world"; - let len = original_data.len(); - - let serialized_data = unsafe { - SerializedData::from_raw_parts( - arena, - NonNull::new(original_data as *const _ as *mut _).unwrap(), - len, - ) - }; - assert_that!(&*serialized_data, eq(b"Hello world")); - } - #[test] fn assert_c_type_sizes() { // TODO: add these same asserts in C++. diff --git a/rust/upb/BUILD b/rust/upb/BUILD index 894d4950e063..b8f4bfc22185 100644 --- a/rust/upb/BUILD +++ b/rust/upb/BUILD @@ -23,6 +23,7 @@ rust_library( "message_value.rs", "mini_table.rs", "opaque_pointee.rs", + "owned_data.rs", "string_view.rs", "wire.rs", ], diff --git a/rust/upb/arena.rs b/rust/upb/arena.rs index 51e8958e4e48..36be4c449cc9 100644 --- a/rust/upb/arena.rs +++ b/rust/upb/arena.rs @@ -3,7 +3,7 @@ use std::alloc::{self, Layout}; use std::cell::UnsafeCell; use std::marker::PhantomData; use std::mem::{align_of, MaybeUninit}; -use std::ptr::NonNull; +use std::ptr::{self, NonNull}; use std::slice; opaque_pointee!(upb_Arena); @@ -91,6 +91,53 @@ impl Arena { // `UPB_MALLOC_ALIGN` boundary. unsafe { slice::from_raw_parts_mut(ptr.cast(), layout.size()) } } + + /// Same as alloc() but panics if `layout.align() > UPB_MALLOC_ALIGN`. + #[allow(clippy::mut_from_ref)] + #[inline] + pub fn checked_alloc(&self, layout: Layout) -> &mut [MaybeUninit] { + assert!(layout.align() <= UPB_MALLOC_ALIGN); + // SAFETY: layout.align() <= UPB_MALLOC_ALIGN asserted. 
+ unsafe { self.alloc(layout) } + } + + /// Copies the T into this arena and returns a pointer to the T data inside + /// the arena. + pub fn copy_in<'a, T: Copy>(&'a self, data: &T) -> &'a T { + let layout = Layout::for_value(data); + let alloc = self.checked_alloc(layout); + + // SAFETY: + // - `alloc` is valid for `layout.size()` bytes, and its uninit bytes are only written + // to, never read from, before being initialized. + // - T is `Copy`, so copying the bytes of the value is sound. + unsafe { + let alloc = alloc.as_mut_ptr().cast::>(); + // let data = (data as *const T).cast::>(); + (*alloc).write(*data) + } + } + + pub fn copy_str_in<'a>(&'a self, s: &str) -> &'a str { + let copied_bytes = self.copy_slice_in(s.as_bytes()); + // SAFETY: `copied_bytes` has the same contents as `s` and so must meet &str + // criteria. + unsafe { std::str::from_utf8_unchecked(copied_bytes) } + } + + pub fn copy_slice_in<'a, T: Copy>(&'a self, data: &[T]) -> &'a [T] { + let layout = Layout::for_value(data); + let alloc: *mut T = self.checked_alloc(layout).as_mut_ptr().cast(); + + // SAFETY: + // - `alloc` is valid for `layout.size()` bytes, and its uninit bytes are only + // written to, never read from, before being initialized. + // - T is `Copy`, so copying the bytes of the values is sound. 
+ unsafe { + ptr::copy_nonoverlapping(data.as_ptr(), alloc, data.len()); + slice::from_raw_parts_mut(alloc, data.len()) + } + } } impl Default for Arena { diff --git a/rust/upb/lib.rs b/rust/upb/lib.rs index f557d1b0240f..2d9bc3e3698c 100644 --- a/rust/upb/lib.rs +++ b/rust/upb/lib.rs @@ -21,7 +21,9 @@ pub use map::{ }; mod message; -pub use message::{upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, RawMessage}; +pub use message::{ + upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New, RawMessage, +}; mod message_value; pub use message_value::{upb_MessageValue, upb_MutableMessageValue}; @@ -31,8 +33,11 @@ pub use mini_table::{upb_MiniTable, RawMiniTable}; mod opaque_pointee; +mod owned_data; +pub use owned_data::OwnedData; + mod string_view; pub use string_view::StringView; -mod wire; -pub use wire::{upb_Decode, upb_Encode, DecodeStatus, EncodeStatus}; +pub mod wire; +pub use wire::{upb_Decode, DecodeStatus, EncodeStatus}; diff --git a/rust/upb/message.rs b/rust/upb/message.rs index 2a9fe91e3030..fd831d7b6dbe 100644 --- a/rust/upb/message.rs +++ b/rust/upb/message.rs @@ -6,6 +6,10 @@ opaque_pointee!(upb_Message); pub type RawMessage = NonNull; extern "C" { + /// SAFETY: No constraints. + pub fn upb_Message_New(mini_table: *const upb_MiniTable, arena: RawArena) + -> Option; + pub fn upb_Message_DeepCopy( dst: RawMessage, src: RawMessage, diff --git a/rust/upb/owned_data.rs b/rust/upb/owned_data.rs new file mode 100644 index 000000000000..0fefa18af60a --- /dev/null +++ b/rust/upb/owned_data.rs @@ -0,0 +1,89 @@ +use crate::Arena; +use std::fmt::{self, Debug}; +use std::ops::Deref; +use std::ptr::NonNull; + +/// An 'owned' T, conceptually similar to a Box where the T is +/// something in a upb Arena. By holding the data pointer and the owned arena +/// together the data lifetime will be maintained. 
+pub struct OwnedData { + data: NonNull, + arena: Arena, +} + +impl OwnedData { + /// Construct `OwnedData` from raw pointers and its owning arena. + /// + /// # Safety + /// - `data` must satisfy the safety constraints of pointer::as_ref::<'a>() + /// where 'a is the passed arena's lifetime (`data` must be valid, have + /// lifetime at least as long as `arena`, and must not mutate while this + /// struct exists) + pub unsafe fn new(data: NonNull, arena: Arena) -> Self { + OwnedData { arena, data } + } + + pub fn data(&self) -> *const T { + self.data.as_ptr() + } + + pub fn as_ref(&self) -> &T { + // SAFETY: + // - `data` is valid under the conditions set on ::new(). + unsafe { self.data.as_ref() } + } + + pub fn into_parts(self) -> (NonNull, Arena) { + (self.data, self.arena) + } +} + +impl Deref for OwnedData { + type Target = T; + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl AsRef for OwnedData { + fn as_ref(&self) -> &T { + self.as_ref() + } +} + +impl Debug for OwnedData { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Debug::fmt(self.deref(), f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str; + + #[test] + fn test_byte_slice_pointer_roundtrip() { + let arena = Arena::new(); + let original_data: &'static [u8] = b"Hello world"; + let owned_data = unsafe { OwnedData::new(original_data.into(), arena) }; + assert_eq!(&*owned_data, b"Hello world"); + } + + #[test] + fn test_alloc_str_roundtrip() { + let arena = Arena::new(); + let s: &str = "Hello"; + let arena_alloc_str: NonNull = arena.copy_str_in(s).into(); + let owned_data = unsafe { OwnedData::new(arena_alloc_str, arena) }; + assert_eq!(&*owned_data, s); + } + + #[test] + fn test_sized_type_roundtrip() { + let arena = Arena::new(); + let arena_alloc_u32: NonNull = arena.copy_in(&7u32).into(); + let owned_data = unsafe { OwnedData::new(arena_alloc_u32, arena) }; + assert_eq!(*owned_data, 7); + } +} diff --git a/rust/upb/wire.rs b/rust/upb/wire.rs index 
2b68145cc293..6c7d81a106f4 100644 --- a/rust/upb/wire.rs +++ b/rust/upb/wire.rs @@ -1,8 +1,12 @@ -use crate::{upb_ExtensionRegistry, upb_MiniTable, RawArena, RawMessage}; +use crate::{ + upb_ExtensionRegistry, upb_Message, upb_Message_New, upb_MiniTable, Arena, OwnedData, RawArena, + RawMessage, +}; +use std::ptr::NonNull; // LINT.IfChange(encode_status) #[repr(C)] -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum EncodeStatus { Ok = 0, OutOfMemory = 1, @@ -13,7 +17,7 @@ pub enum EncodeStatus { // LINT.IfChange(decode_status) #[repr(C)] -#[derive(PartialEq, Eq, Copy, Clone)] +#[derive(PartialEq, Eq, Copy, Clone, Debug)] pub enum DecodeStatus { Ok = 0, Malformed = 1, @@ -25,6 +29,68 @@ pub enum DecodeStatus { } // LINT.ThenChange() +/// If Err, then EncodeStatus != Ok. +/// +/// SAFETY: +/// - `msg` must be associated with `mini_table`. +pub unsafe fn encode( + msg: RawMessage, + mini_table: *const upb_MiniTable, +) -> Result, EncodeStatus> { + let arena = Arena::new(); + let mut buf: *mut u8 = std::ptr::null_mut(); + let mut len = 0usize; + let status = upb_Encode(msg, mini_table, 0, arena.raw(), &mut buf, &mut len); + if status == EncodeStatus::Ok { + assert!(!buf.is_null()); // EncodeStatus Ok should never return NULL data, even for len=0. + // SAFETY: upb guarantees that `buf` is valid to read for `len`. + let slice = NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(buf, len)); + Ok(OwnedData::new(slice, arena)) + } else { + Err(status) + } +} + +/// Decodes the provided buffer into a new message. If Err, then DecodeStatus != +/// Ok. +pub fn decode_new( + buf: &[u8], + mini_table: *const upb_MiniTable, +) -> Result, DecodeStatus> { + let arena = Arena::new(); + // SAFETY: No constraints. + let msg = unsafe { upb_Message_New(mini_table, arena.raw()).unwrap() }; + + // SAFETY: `msg` was just created as mutable and associated with `mini_table`. 
+ let result = unsafe { decode(buf, msg, mini_table, &arena) }; + + // SAFETY: + // - `msg` was allocated using `arena`. + // - `msg` will not be mutated after this line. + result.map(|_| unsafe { OwnedData::new(msg, arena) }) +} + +/// Decodes into the provided message (merge semantics). If Err, then +/// DecodeStatus != Ok. +/// +/// SAFETY: +/// - `msg` must be mutable. +/// - `msg` must be associated with `mini_table`. +pub unsafe fn decode( + buf: &[u8], + msg: RawMessage, + mini_table: *const upb_MiniTable, + arena: &Arena, +) -> Result<(), DecodeStatus> { + let len = buf.len(); + let buf = buf.as_ptr(); + let status = upb_Decode(buf, len, msg, mini_table, std::ptr::null(), 0, arena.raw()); + match status { + DecodeStatus::Ok => Ok(()), + _ => Err(status), + } +} + extern "C" { pub fn upb_Encode( msg: RawMessage, diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 6ee718fc181c..d9fd68c5fcab 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -68,35 +68,17 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { case Kernel::kUpb: ctx.Emit({{"minitable", UpbMinitableName(msg)}}, R"rs( - let arena = $pbr$::Arena::new(); // SAFETY: $minitable$ is a static of a const object. let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) }; - let options = 0; - let mut buf: *mut u8 = std::ptr::null_mut(); - let mut len = 0; - - // SAFETY: `mini_table` is the corresponding one that was used to - // construct `self.raw_msg()`. - let status = unsafe { - $pbr$::upb_Encode(self.raw_msg(), mini_table, options, arena.raw(), - &mut buf, &mut len) + // SAFETY: $minitable$ is the one associated with raw_msg(). + let encoded = unsafe { + $pbr$::wire::encode(self.raw_msg(), mini_table) }; //~ TODO: Currently serialize() on the Rust API is an //~ infallible fn, so if upb signals an error here we can only panic. 
- assert!(status == $pbr$::EncodeStatus::Ok); - let data = if len == 0 { - std::ptr::NonNull::dangling() - } else { - std::ptr::NonNull::new(buf).unwrap() - }; - - // SAFETY: - // - `arena` allocated `data`. - // - `data` is valid for reads up to `len` and will not be mutated. - unsafe { - $pbr$::SerializedData::from_raw_parts(arena, data, len) - } + let serialized = encoded.expect("serialize is not allowed to fail"); + serialized )rs"); return; } @@ -131,27 +113,25 @@ void MessageClearAndParse(Context& ctx, const Descriptor& msg) { let mut msg = Self::new(); // SAFETY: $minitable$ is a static of a const object. let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) }; - let ext_reg = std::ptr::null(); - let options = 0; // SAFETY: // - `data.as_ptr()` is valid to read for `data.len()` // - `mini_table` is the one used to construct `msg.raw_msg()` // - `msg.arena().raw()` is held for the same lifetime as `msg`. let status = unsafe { - $pbr$::upb_Decode( - data.as_ptr(), data.len(), msg.raw_msg(), - mini_table, ext_reg, options, msg.arena().raw()) + $pbr$::wire::decode( + data, msg.raw_msg(), + mini_table, msg.arena()) }; match status { - $pbr$::DecodeStatus::Ok => { + Ok(_) => { //~ This swap causes the old self.inner.arena to be moved into `msg` //~ which we immediately drop, which will release any previous //~ message that was held here. std::mem::swap(self, &mut msg); Ok(()) } - _ => Err($pb$::ParseError) + Err(_) => Err($pb$::ParseError) } )rs"); return;