Skip to content

Commit

Permalink
Create the concept of 'owned data' in upb/rust as a generalization of…
Browse files Browse the repository at this point in the history
… the upb.rs SerializedData (which is an arena + data for arbitrary types, both thin and wide ref types), use that for the wire parse/serialize path.

PiperOrigin-RevId: 626012269
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Apr 19, 2024
1 parent 0cda26d commit 332850e
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 97 deletions.
61 changes: 1 addition & 60 deletions rust/upb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ use crate::{
};
use core::fmt::Debug;
use std::alloc::Layout;
use std::fmt;
use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::ops::Deref;
use std::ptr::{self, NonNull};
use std::slice;
use std::sync::OnceLock;
Expand Down Expand Up @@ -60,48 +58,7 @@ impl ScratchSpace {
}
}

/// Serialized Protobuf wire format data.
///
/// It's typically produced by `<Message>::serialize()`.
///
/// Bundles a raw pointer/length pair with the arena that owns the bytes, so
/// holding this value keeps the arena (and therefore the buffer) alive.
pub struct SerializedData {
    // Pointer to the first serialized byte.
    data: NonNull<u8>,
    // Number of bytes readable starting at `data`.
    len: usize,

    // The arena that owns `data`. Never read; held only so the arena lives
    // as long as this struct does.
    _arena: Arena,
}

impl SerializedData {
    /// Construct `SerializedData` from raw pointers and its owning arena.
    ///
    /// # Safety
    /// - `arena` must have allocated `data`
    /// - `data` must be readable for `len` bytes and not mutate while this
    ///   struct exists
    pub unsafe fn from_raw_parts(arena: Arena, data: NonNull<u8>, len: usize) -> Self {
        SerializedData { _arena: arena, data, len }
    }

    /// Gets a raw slice pointer covering the `len` bytes starting at `data`.
    pub fn as_ptr(&self) -> *const [u8] {
        ptr::slice_from_raw_parts(self.data.as_ptr(), self.len)
    }
}

impl Deref for SerializedData {
    type Target = [u8];

    /// Borrows the serialized bytes as a plain byte slice.
    fn deref(&self) -> &[u8] {
        let start: *const u8 = self.data.as_ptr();
        let byte_count = self.len;
        // SAFETY: the caller of `SerializedData::from_raw_parts` promised that
        // `data` is readable for `len` bytes and is not mutated while this
        // struct exists.
        unsafe { slice::from_raw_parts(start, byte_count) }
    }
}

impl fmt::Debug for SerializedData {
    /// Formats exactly like the underlying `[u8]` slice.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bytes: &[u8] = self.deref();
        <[u8] as fmt::Debug>::fmt(bytes, f)
    }
}
/// Serialized Protobuf wire format data: a byte slice held together with its
/// arena by `upb::OwnedData`.
pub type SerializedData = upb::OwnedData<[u8]>;

impl SettableValue<[u8]> for SerializedData {
fn set_on<'msg>(self, _private: Private, mut mutator: Mut<'msg, [u8]>)
Expand Down Expand Up @@ -814,22 +771,6 @@ mod tests {
use super::*;
use googletest::prelude::*;

// Wraps a byte literal in `SerializedData` and checks that dereferencing
// yields the same bytes back.
#[test]
fn test_serialized_data_roundtrip() {
    let arena = Arena::new();
    let original_data = b"Hello world";
    let len = original_data.len();

    // The literal is not allocated by `arena`; the test only ever reads
    // through the pointer, so the `*mut` cast is never used to write.
    let serialized_data = unsafe {
        SerializedData::from_raw_parts(
            arena,
            NonNull::new(original_data as *const _ as *mut _).unwrap(),
            len,
        )
    };
    assert_that!(&*serialized_data, eq(b"Hello world"));
}

#[test]
fn assert_c_type_sizes() {
// TODO: add these same asserts in C++.
Expand Down
1 change: 1 addition & 0 deletions rust/upb/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ rust_library(
"message_value.rs",
"mini_table.rs",
"opaque_pointee.rs",
"owned_data.rs",
"string_view.rs",
"wire.rs",
],
Expand Down
49 changes: 48 additions & 1 deletion rust/upb/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::alloc::{self, Layout};
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::{align_of, MaybeUninit};
use std::ptr::NonNull;
use std::ptr::{self, NonNull};
use std::slice;

opaque_pointee!(upb_Arena);
Expand Down Expand Up @@ -91,6 +91,53 @@ impl Arena {
// `UPB_MALLOC_ALIGN` boundary.
unsafe { slice::from_raw_parts_mut(ptr.cast(), layout.size()) }
}

/// Same as alloc() but panics if `layout.align() > UPB_MALLOC_ALIGN`.
///
/// This is the safe entry point to `alloc()`: the alignment precondition is
/// verified at runtime instead of being left to the caller.
///
/// # Panics
/// Panics if `layout.align()` exceeds `UPB_MALLOC_ALIGN`.
#[allow(clippy::mut_from_ref)]
#[inline]
pub fn checked_alloc(&self, layout: Layout) -> &mut [MaybeUninit<u8>] {
    assert!(layout.align() <= UPB_MALLOC_ALIGN);
    // SAFETY: layout.align() <= UPB_MALLOC_ALIGN asserted.
    unsafe { self.alloc(layout) }
}

/// Copies `data` into this arena and returns a reference to the copy, which
/// lives as long as the borrow of the arena.
///
/// # Panics
/// Panics (via `checked_alloc`) if the alignment of `T` exceeds
/// `UPB_MALLOC_ALIGN`.
pub fn copy_in<'a, T: Copy>(&'a self, data: &T) -> &'a T {
    let slot: *mut MaybeUninit<T> =
        self.checked_alloc(Layout::for_value(data)).as_mut_ptr().cast();

    // SAFETY:
    // - `slot` points at a fresh allocation made for the layout of `*data`,
    //   so it is large enough for a `T`; `checked_alloc` asserted that the
    //   required alignment does not exceed `UPB_MALLOC_ALIGN`.
    // - The uninitialized bytes are only written here, never read before the
    //   write.
    // - `T: Copy`, so duplicating the value is sound.
    unsafe { (*slot).write(*data) }
}

/// Copies the string `s` into this arena and returns the arena-backed copy.
pub fn copy_str_in<'a>(&'a self, s: &str) -> &'a str {
    let arena_bytes: &'a [u8] = self.copy_slice_in(s.as_bytes());
    // SAFETY: `arena_bytes` has the same contents as `s`, which is already
    // valid UTF-8, so it meets the `&str` criteria.
    unsafe { std::str::from_utf8_unchecked(arena_bytes) }
}

/// Copies every element of `data` into this arena and returns the
/// arena-backed slice.
///
/// # Panics
/// Panics (via `checked_alloc`) if the alignment of `T` exceeds
/// `UPB_MALLOC_ALIGN`.
pub fn copy_slice_in<'a, T: Copy>(&'a self, data: &[T]) -> &'a [T] {
    let count = data.len();
    let dst = self.checked_alloc(Layout::for_value(data)).as_mut_ptr().cast::<T>();

    // SAFETY:
    // - `dst` points at a fresh allocation made for the layout of `data`, so
    //   it has room for `count` values of `T`; the uninitialized bytes are
    //   written before anything reads them.
    // - The source and the new allocation cannot overlap.
    // - `T: Copy`, so copying the bytes of the values is sound.
    unsafe {
        ptr::copy_nonoverlapping(data.as_ptr(), dst, count);
        slice::from_raw_parts(dst, count)
    }
}
}

impl Default for Arena {
Expand Down
11 changes: 8 additions & 3 deletions rust/upb/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ pub use map::{
};

mod message;
pub use message::{upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, RawMessage};
pub use message::{
upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New, RawMessage,
};

mod message_value;
pub use message_value::{upb_MessageValue, upb_MutableMessageValue};
Expand All @@ -31,8 +33,11 @@ pub use mini_table::{upb_MiniTable, RawMiniTable};

mod opaque_pointee;

mod owned_data;
pub use owned_data::OwnedData;

mod string_view;
pub use string_view::StringView;

mod wire;
pub use wire::{upb_Decode, upb_Encode, DecodeStatus, EncodeStatus};
pub mod wire;
pub use wire::{upb_Decode, DecodeStatus, EncodeStatus};
4 changes: 4 additions & 0 deletions rust/upb/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ opaque_pointee!(upb_Message);
pub type RawMessage = NonNull<upb_Message>;

extern "C" {
/// SAFETY: No constraints.
pub fn upb_Message_New(mini_table: *const upb_MiniTable, arena: RawArena)
-> Option<RawMessage>;

pub fn upb_Message_DeepCopy(
dst: RawMessage,
src: RawMessage,
Expand Down
89 changes: 89 additions & 0 deletions rust/upb/owned_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use crate::Arena;
use std::fmt::{self, Debug};
use std::ops::Deref;
use std::ptr::NonNull;

/// An 'owned' T, conceptually similar to a Box<T> where the T is
/// something in a upb Arena. By holding the data pointer and the owned arena
/// together the data lifetime will be maintained.
///
/// `T` may be unsized (e.g. `[u8]` or `str`), in which case `data` is a wide
/// pointer.
pub struct OwnedData<T: ?Sized> {
    // Pointer to the value; validity is promised by the caller of `new()`.
    data: NonNull<T>,
    // Arena held alongside `data` so the pointee outlives this struct.
    arena: Arena,
}

impl<T: ?Sized> OwnedData<T> {
/// Construct `OwnedData` from raw pointers and its owning arena.
///
/// # Safety
/// - `data` must satisfy the safety constraints of pointer::as_ref::<'a>()
/// where 'a is the passed arena's lifetime (`data` must be valid, have
/// lifetime at least as long as `arena`, and must not mutate while this
/// struct exists)
pub unsafe fn new(data: NonNull<T>, arena: Arena) -> Self {
OwnedData { arena, data }
}

pub fn data(&self) -> *const T {
self.data.as_ptr()
}

pub fn as_ref(&self) -> &T {
// SAFETY:
// - `data` is valid under the conditions set on ::new().
unsafe { self.data.as_ref() }
}

pub fn into_parts(self) -> (NonNull<T>, Arena) {
(self.data, self.arena)
}
}

impl<T: ?Sized> Deref for OwnedData<T> {
    type Target = T;

    /// Dereferences to the data, same as the inherent `as_ref()`.
    fn deref(&self) -> &T {
        let borrowed: &T = self.as_ref();
        borrowed
    }
}

impl<T: ?Sized> AsRef<T> for OwnedData<T> {
    /// Borrows the data.
    fn as_ref(&self) -> &T {
        // Not recursive: inherent methods take priority over trait methods
        // during method resolution, so this calls `OwnedData::as_ref`.
        self.as_ref()
    }
}

// `?Sized` is required so that unsized payloads such as `OwnedData<[u8]>`
// and `OwnedData<str>` are also `Debug`; a bare `T: Debug` bound is
// implicitly `T: Sized` and would exclude them.
impl<T: ?Sized + Debug> Debug for OwnedData<T> {
    /// Formats exactly like the data itself; the arena is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self.deref(), f)
    }
}

// Round-trip tests: wrap a pointer in `OwnedData` and read the same value back
// through `Deref`, for a wide `[u8]` pointer, a wide `str` pointer and a thin
// `u32` pointer.
#[cfg(test)]
mod tests {
    use super::*;
    use std::str;

    #[test]
    fn test_byte_slice_pointer_roundtrip() {
        let arena = Arena::new();
        // A static literal: not allocated by `arena`, but valid for the whole
        // program and never mutated.
        let original_data: &'static [u8] = b"Hello world";
        let owned_data = unsafe { OwnedData::new(original_data.into(), arena) };
        assert_eq!(&*owned_data, b"Hello world");
    }

    #[test]
    fn test_alloc_str_roundtrip() {
        let arena = Arena::new();
        let s: &str = "Hello";
        // Copy the string into the arena, then move the arena into the
        // `OwnedData` together with the pointer to the copy.
        let arena_alloc_str: NonNull<str> = arena.copy_str_in(s).into();
        let owned_data = unsafe { OwnedData::new(arena_alloc_str, arena) };
        assert_eq!(&*owned_data, s);
    }

    #[test]
    fn test_sized_type_roundtrip() {
        let arena = Arena::new();
        // Same as above, with a sized value copied into the arena.
        let arena_alloc_u32: NonNull<u32> = arena.copy_in(&7u32).into();
        let owned_data = unsafe { OwnedData::new(arena_alloc_u32, arena) };
        assert_eq!(*owned_data, 7);
    }
}
72 changes: 69 additions & 3 deletions rust/upb/wire.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::{upb_ExtensionRegistry, upb_MiniTable, RawArena, RawMessage};
use crate::{
upb_ExtensionRegistry, upb_Message, upb_Message_New, upb_MiniTable, Arena, OwnedData, RawArena,
RawMessage,
};
use std::ptr::NonNull;

// LINT.IfChange(encode_status)
#[repr(C)]
#[derive(PartialEq, Eq, Copy, Clone)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum EncodeStatus {
Ok = 0,
OutOfMemory = 1,
Expand All @@ -13,7 +17,7 @@ pub enum EncodeStatus {

// LINT.IfChange(decode_status)
#[repr(C)]
#[derive(PartialEq, Eq, Copy, Clone)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum DecodeStatus {
Ok = 0,
Malformed = 1,
Expand All @@ -25,6 +29,68 @@ pub enum DecodeStatus {
}
// LINT.ThenChange()

/// Serializes `msg` into a freshly created arena and returns the encoded
/// bytes together with that arena. If Err, then EncodeStatus != Ok.
///
/// # Safety
/// - `msg` must be associated with `mini_table`.
pub unsafe fn encode(
    msg: RawMessage,
    mini_table: *const upb_MiniTable,
) -> Result<OwnedData<[u8]>, EncodeStatus> {
    let arena = Arena::new();
    let mut out_ptr: *mut u8 = std::ptr::null_mut();
    let mut out_len = 0usize;

    let status = upb_Encode(msg, mini_table, 0, arena.raw(), &mut out_ptr, &mut out_len);
    if status != EncodeStatus::Ok {
        return Err(status);
    }

    // EncodeStatus Ok should never return NULL data, even for len=0.
    assert!(!out_ptr.is_null());

    // SAFETY: upb guarantees that the buffer is valid to read for `out_len`
    // bytes, and the pointer was just checked to be non-null.
    let bytes = NonNull::new_unchecked(std::ptr::slice_from_raw_parts_mut(out_ptr, out_len));
    Ok(OwnedData::new(bytes, arena))
}

/// Decodes the provided buffer into a new message. If Err, then DecodeStatus !=
/// Ok.
pub fn decode_new(
buf: &[u8],
mini_table: *const upb_MiniTable,
) -> Result<OwnedData<RawMessage>, DecodeStatus> {
let arena = Arena::new();
// SAFETY: No constraints.
let msg = unsafe { upb_Message_New(mini_table, arena.raw()).unwrap() };

// SAFETY: `msg` was just created as mutable and associated with `mini_table`.
let result = unsafe { decode(buf, msg, mini_table, &arena) };

// SAFETY:
// - `msg` was allocated using `arena.
// - `msg` will not be mutated after this line.
result.map(|_| unsafe { OwnedData::new(msg, arena) })
}

/// Decodes `buf` into the provided message (merge semantics), allocating from
/// `arena`. If Err, then DecodeStatus != Ok.
///
/// # Safety
/// - `msg` must be mutable.
/// - `msg` must be associated with `mini_table`.
pub unsafe fn decode(
    buf: &[u8],
    msg: RawMessage,
    mini_table: *const upb_MiniTable,
    arena: &Arena,
) -> Result<(), DecodeStatus> {
    // No extension registry (null) and no decode options (0).
    let status =
        upb_Decode(buf.as_ptr(), buf.len(), msg, mini_table, std::ptr::null(), 0, arena.raw());
    if status == DecodeStatus::Ok {
        Ok(())
    } else {
        Err(status)
    }
}

extern "C" {
pub fn upb_Encode(
msg: RawMessage,
Expand Down
Loading

0 comments on commit 332850e

Please sign in to comment.