Skip to content

Commit

Permalink
Create the concept of 'owned data' in upb/rust (similar to MOA but fo…
Browse files Browse the repository at this point in the history
…r arbitrary types) and use that on the wire parse/serialize path.

PiperOrigin-RevId: 626012269
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Apr 18, 2024
1 parent 1d6fdc1 commit 5d22534
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 79 deletions.
44 changes: 1 addition & 43 deletions rust/upb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use core::fmt::Debug;
use std::alloc::Layout;
use std::fmt;
use std::mem::{size_of, ManuallyDrop, MaybeUninit};
use std::ops::Deref;
use std::ptr::{self, NonNull};
use std::slice;
use std::sync::OnceLock;
Expand Down Expand Up @@ -60,48 +59,7 @@ impl ScratchSpace {
}
}

/// Serialized Protobuf wire format data.
///
/// It's typically produced by `<Message>::serialize()`.
pub struct SerializedData {
data: NonNull<u8>,
len: usize,

// The arena that owns `data`.
_arena: Arena,
}

impl SerializedData {
/// Construct `SerializedData` from raw pointers and its owning arena.
///
/// # Safety
/// - `arena` must be have allocated `data`
/// - `data` must be readable for `len` bytes and not mutate while this
/// struct exists
pub unsafe fn from_raw_parts(arena: Arena, data: NonNull<u8>, len: usize) -> Self {
SerializedData { _arena: arena, data, len }
}

/// Gets a raw slice pointer.
pub fn as_ptr(&self) -> *const [u8] {
ptr::slice_from_raw_parts(self.data.as_ptr(), self.len)
}
}

impl Deref for SerializedData {
type Target = [u8];
fn deref(&self) -> &Self::Target {
// SAFETY: `data` is valid for `len` bytes as promised by
// the caller of `SerializedData::from_raw_parts`.
unsafe { slice::from_raw_parts(self.data.as_ptr(), self.len) }
}
}

impl fmt::Debug for SerializedData {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(self.deref(), f)
}
}
pub type SerializedData = upb::OwnedData<[u8]>;

impl SettableValue<[u8]> for SerializedData {
fn set_on<'msg>(self, _private: Private, mut mutator: Mut<'msg, [u8]>)
Expand Down
1 change: 1 addition & 0 deletions rust/upb/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ rust_library(
"message_value.rs",
"mini_table.rs",
"opaque_pointee.rs",
"owned_data.rs",
"string_view.rs",
"wire.rs",
],
Expand Down
11 changes: 8 additions & 3 deletions rust/upb/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ pub use map::{
};

mod message;
pub use message::{upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, RawMessage};
pub use message::{
upb_Message, upb_Message_DeepClone, upb_Message_DeepCopy, upb_Message_New, RawMessage,
};

mod message_value;
pub use message_value::{upb_MessageValue, upb_MutableMessageValue};
Expand All @@ -31,8 +33,11 @@ pub use mini_table::{upb_MiniTable, RawMiniTable};

mod opaque_pointee;

mod owned_data;
pub use owned_data::OwnedData;

mod string_view;
pub use string_view::StringView;

mod wire;
pub use wire::{upb_Decode, upb_Encode, DecodeStatus, EncodeStatus};
pub mod wire;
pub use wire::{upb_Decode, DecodeStatus, EncodeStatus};
4 changes: 4 additions & 0 deletions rust/upb/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ opaque_pointee!(upb_Message);
pub type RawMessage = NonNull<upb_Message>;

extern "C" {
/// SAFETY: No constraints.
pub fn upb_Message_New(mini_table: *const upb_MiniTable, arena: RawArena)
-> Option<RawMessage>;

pub fn upb_Message_DeepCopy(
dst: RawMessage,
src: RawMessage,
Expand Down
63 changes: 63 additions & 0 deletions rust/upb/owned_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use crate::Arena;
use std::fmt::{self, Debug};
use std::ops::Deref;

/// An 'owned' T, conceptually similar to a Box<T> where the T is
/// something in a upb Arena. By holding the data pointer and the owned arena
/// together the data lifetime will be maintained.
pub struct OwnedData<T: ?Sized> {
data: *const T,
arena: Arena,
}

impl<T: ?Sized> OwnedData<T> {
/// Construct `OwnedData` from raw pointers and its owning arena.
///
/// # Safety
/// - `data` must have the lifetime of `arena` (e.g. it was allocated using
/// `arena` or other arenas fused to it).
/// - `data` must not mutate while this struct exists
pub unsafe fn new(data: *const T, arena: Arena) -> Self {
OwnedData { arena, data }
}

/// Gets a raw slice pointer.
pub fn data(&self) -> *const T {
self.data
}

pub fn as_ref(&self) -> &T {
// SAFETY:
// - `data` is valid under the conditions set on ::new().
unsafe { self.data.as_ref().unwrap() }
}

pub fn into_parts(self) -> (*const T, Arena) {
(self.data, self.arena)
}
}

impl<T> OwnedData<[T]> {
pub unsafe fn from_raw_parts(arena: Arena, data: std::ptr::NonNull<T>, len: usize) -> Self {
OwnedData::new(std::slice::from_raw_parts(data.as_ptr(), len), arena)
}
}

impl<T: ?Sized> Deref for OwnedData<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.as_ref()
}
}

impl<T: ?Sized> AsRef<T> for OwnedData<T> {
fn as_ref(&self) -> &T {
self.as_ref()
}
}

impl<T: Debug> Debug for OwnedData<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(self.deref(), f)
}
}
72 changes: 69 additions & 3 deletions rust/upb/wire.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::{upb_ExtensionRegistry, upb_MiniTable, RawArena, RawMessage};
use crate::{
upb_ExtensionRegistry, upb_Message, upb_Message_New, upb_MiniTable, Arena, OwnedData, RawArena,
RawMessage,
};
use std::slice;

// LINT.IfChange(encode_status)
#[repr(C)]
#[derive(PartialEq, Eq, Copy, Clone)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum EncodeStatus {
Ok = 0,
OutOfMemory = 1,
Expand All @@ -13,7 +17,7 @@ pub enum EncodeStatus {

// LINT.IfChange(decode_status)
#[repr(C)]
#[derive(PartialEq, Eq, Copy, Clone)]
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum DecodeStatus {
Ok = 0,
Malformed = 1,
Expand All @@ -25,6 +29,68 @@ pub enum DecodeStatus {
}
// LINT.ThenChange()

/// If Err, then EncodeStatus != Ok.
///
/// SAFETY:
/// - `msg` must be associated with `mini_table`.
pub unsafe fn encode(
msg: RawMessage,
mini_table: *const upb_MiniTable,
) -> Result<OwnedData<[u8]>, EncodeStatus> {
let arena = Arena::new();
let mut buf: *mut u8 = std::ptr::null_mut();
let mut len = 0usize;
let status = upb_Encode(msg, mini_table, 0, arena.raw(), &mut buf, &mut len);
if status == EncodeStatus::Ok {
assert!(!buf.is_null()); // EncodeStatus Ok should never return NULL data, even for len=0.
// SAFETY: upb guarantees that `buf` is valid to read for `len`.
let slice = unsafe { slice::from_raw_parts(buf, len) };
Ok(OwnedData::new(slice, arena))
} else {
Err(status)
}
}

/// Decodes the provided buffer into a new message. If Err, then DecodeStatus !=
/// Ok.
pub fn decode_new(
buf: &[u8],
mini_table: *const upb_MiniTable,
) -> Result<OwnedData<upb_Message>, DecodeStatus> {
let arena = Arena::new();
// SAFETY: No constraints.
let msg = unsafe { upb_Message_New(mini_table, arena.raw()).unwrap() };

// SAFETY: `msg` was just created as mutable and associated with `mini_table`.
let result = unsafe { decode(buf, msg, mini_table, &arena) };

// SAFETY:
// - `msg` was allocated using `arena.
// - `msg` will not be mutated after this line.
result.map(|_| unsafe { OwnedData::new(msg.as_ptr(), arena) })
}

/// Decodes into the provided message (merge semantics). If Err, then
/// DecodeStatus != Ok.
///
/// SAFETY:
/// - `msg` must be mutable.
/// - `msg` must be associated with `mini_table`.
pub unsafe fn decode(
buf: &[u8],
msg: RawMessage,
mini_table: *const upb_MiniTable,
arena: &Arena,
) -> Result<(), DecodeStatus> {
let len = buf.len();
let buf = buf.as_ptr();
let status = upb_Decode(buf, len, msg, mini_table, std::ptr::null(), 0, arena.raw());
match status {
DecodeStatus::Ok => Ok(()),
_ => Err(status),
}
}

extern "C" {
pub fn upb_Encode(
msg: RawMessage,
Expand Down
40 changes: 10 additions & 30 deletions src/google/protobuf/compiler/rust/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,35 +68,17 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) {
case Kernel::kUpb:
ctx.Emit({{"minitable", UpbMinitableName(msg)}},
R"rs(
let arena = $pbr$::Arena::new();
// SAFETY: $minitable$ is a static of a const object.
let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) };
let options = 0;
let mut buf: *mut u8 = std::ptr::null_mut();
let mut len = 0;
// SAFETY: `mini_table` is the corresponding one that was used to
// construct `self.raw_msg()`.
let status = unsafe {
$pbr$::upb_Encode(self.raw_msg(), mini_table, options, arena.raw(),
&mut buf, &mut len)
// SAFETY: $minitable$ is the one associated with raw_msg().
let encoded = unsafe {
$pbr$::wire::encode(self.raw_msg(), mini_table)
};
//~ TODO: Currently serialize() on the Rust API is an
//~ infallible fn, so if upb signals an error here we can only panic.
assert!(status == $pbr$::EncodeStatus::Ok);
let data = if len == 0 {
std::ptr::NonNull::dangling()
} else {
std::ptr::NonNull::new(buf).unwrap()
};
// SAFETY:
// - `arena` allocated `data`.
// - `data` is valid for reads up to `len` and will not be mutated.
unsafe {
$pbr$::SerializedData::from_raw_parts(arena, data, len)
}
let serialized = encoded.expect("serialize is not allowed to fail");
serialized
)rs");
return;
}
Expand Down Expand Up @@ -131,27 +113,25 @@ void MessageClearAndParse(Context& ctx, const Descriptor& msg) {
let mut msg = Self::new();
// SAFETY: $minitable$ is a static of a const object.
let mini_table = unsafe { $std$::ptr::addr_of!($minitable$) };
let ext_reg = std::ptr::null();
let options = 0;
// SAFETY:
// - `data.as_ptr()` is valid to read for `data.len()`
// - `mini_table` is the one used to construct `msg.raw_msg()`
// - `msg.arena().raw()` is held for the same lifetime as `msg`.
let status = unsafe {
$pbr$::upb_Decode(
data.as_ptr(), data.len(), msg.raw_msg(),
mini_table, ext_reg, options, msg.arena().raw())
$pbr$::wire::decode(
data, msg.raw_msg(),
mini_table, msg.arena())
};
match status {
$pbr$::DecodeStatus::Ok => {
Ok(_) => {
//~ This swap causes the old self.inner.arena to be moved into `msg`
//~ which we immediately drop, which will release any previous
//~ message that was held here.
std::mem::swap(self, &mut msg);
Ok(())
}
_ => Err($pb$::ParseError)
Err(_) => Err($pb$::ParseError)
}
)rs");
return;
Expand Down

0 comments on commit 5d22534

Please sign in to comment.