From 7746bc637e9b6f76d83259e1eac3a17e5dd3ac06 Mon Sep 17 00:00:00 2001 From: Emanuele Giaquinta Date: Sat, 23 Nov 2024 17:30:13 +0200 Subject: [PATCH] Refactor serializer Replace rmp_serde::Serializer with a simpler implementation derived from it. Signed-off-by: Emanuele Giaquinta --- Cargo.lock | 12 - Cargo.toml | 1 - src/serialize/ext.rs | 5 +- src/serialize/serializer.rs | 489 +++++++++++++++++++++++++++++++++++- 4 files changed, 488 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f900acd..cd0a2470 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,7 +113,6 @@ dependencies = [ "pyo3", "pyo3-build-config", "rmp", - "rmp-serde", "serde", "serde_bytes", "simdutf8", @@ -196,17 +195,6 @@ dependencies = [ "paste", ] -[[package]] -name = "rmp-serde" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" -dependencies = [ - "byteorder", - "rmp", - "serde", -] - [[package]] name = "serde" version = "1.0.215" diff --git a/Cargo.toml b/Cargo.toml index 37d9d6fd..259e1471 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,6 @@ itoa = { version = "1", default-features = false } once_cell = { version = "1", default-features = false, features = ["race"] } pyo3 = { version = "^0.22.6", default-features = false, features = ["extension-module"] } rmp = { version = "^0.8.14", default-features = false, features = ["std"] } -rmp-serde = { version = "1", default-features = false } serde = { version = "1", default-features = false } serde_bytes = { version = "0.11.15", default-features = false, features = ["std"] } simdutf8 = { version = "0.1.5", default-features = false, features = ["std"] } diff --git a/src/serialize/ext.rs b/src/serialize/ext.rs index c866c71a..afc86797 100644 --- a/src/serialize/ext.rs +++ b/src/serialize/ext.rs @@ -31,9 +31,6 @@ impl Serialize for Ext { let length = unsafe { PyBytes_GET_SIZE((*ext).data) as usize }; let data = unsafe { std::slice::from_raw_parts(buffer, length) }; - serializer.serialize_newtype_struct( - rmp_serde::MSGPACK_EXT_STRUCT_NAME, - &(tag as i8, ByteBuf::from(data)), - ) + serializer.serialize_newtype_variant("", tag as u32, "", &ByteBuf::from(data)) } } diff --git a/src/serialize/serializer.rs b/src/serialize/serializer.rs index 9138b042..c9e1a05e 100644 --- a/src/serialize/serializer.rs +++ b/src/serialize/serializer.rs @@ -17,11 +17,496 @@ use crate::serialize::tuple::*; use crate::serialize::uuid::*; use crate::serialize::writer::*; use crate::typeref::*; -use serde::ser::{Serialize, Serializer}; +use serde::ser::{Impossible, Serialize, SerializeMap, SerializeSeq, Serializer}; use std::ptr::NonNull; pub const RECURSION_LIMIT: u8 = 255; +#[derive(Debug)] +pub enum Error { + Custom(String), + Write, +} + +impl std::fmt::Display for Error { + #[cold] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match *self { + Error::Custom(ref msg) => f.write_str(msg), + Error::Write => f.write_str("write error"), + } + } +} + +impl From for Error { + #[cold] + fn from(_: rmp::encode::ValueWriteError) -> Error { + Error::Write + } +} + +impl serde::ser::Error for Error { + #[cold] + fn custom(msg: T) -> Error + where + T: std::fmt::Display, + { + Error::Custom(msg.to_string()) + } +} + +impl std::error::Error for Error {} + +struct ExtSerializer<'a, W> { + tag: i8, + writer: &'a mut W, +} + +impl<'a, W> ExtSerializer<'a, W> +where + W: std::io::Write, +{ + #[inline] + fn new(tag: i8, writer: &'a mut W) -> Self { + Self { + tag: tag, + writer: writer, + } + } +} + +impl Serializer for &mut ExtSerializer<'_, W> +where + W: std::io::Write, +{ + type Ok = (); + type Error = Error; + + type SerializeSeq = Impossible<(), Error>; + type SerializeTuple = Impossible<(), Error>; + type SerializeTupleStruct = Impossible<(), Error>; + type SerializeTupleVariant = Impossible<(), Error>; + type SerializeMap = Impossible<(), Error>; + type SerializeStruct = Impossible<(), Error>; + type SerializeStructVariant = Impossible<(), Error>; + + fn serialize_bytes(self, value: &[u8]) -> Result { + rmp::encode::write_ext_meta(self.writer, value.len() as u32, self.tag)?; + self.writer.write_all(value).map_err(|_| Error::Write) + } + + fn serialize_bool(self, _value: bool) -> Result { + unreachable!(); + } + + fn serialize_i8(self, _value: i8) -> Result { + unreachable!(); + } + + fn serialize_i16(self, _value: i16) -> Result { + unreachable!(); + } + + fn serialize_i32(self, _value: i32) -> Result { + unreachable!(); + } + + fn serialize_i64(self, _value: i64) -> Result { + unreachable!(); + } + + fn serialize_u8(self, _value: u8) -> Result { + unreachable!(); + } + + fn serialize_u16(self, _value: u16) -> Result { + unreachable!(); + } + + fn serialize_u32(self, _value: u32) -> Result { + unreachable!(); + } + + fn serialize_u64(self, _value: u64) -> Result { + unreachable!(); + } + + fn serialize_f32(self, _value: f32) -> Result { + unreachable!(); + } + + fn serialize_f64(self, _value: f64) -> Result { + unreachable!(); + } + + fn serialize_char(self, _value: char) -> Result { + unreachable!(); + } + + fn serialize_str(self, _value: &str) -> Result { + unreachable!(); + } + + fn serialize_none(self) -> Result { + unreachable!(); + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + unreachable!(); + } + + fn serialize_unit(self) -> Result { + unreachable!(); + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unreachable!(); + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unreachable!(); + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + unreachable!(); + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + unreachable!(); + } + + fn serialize_seq(self, _len: Option) -> Result { + unreachable!(); + } + + fn serialize_tuple(self, _len: usize) -> Result { + unreachable!(); + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_map(self, _len: Option) -> Result { + unreachable!(); + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } +} + +pub struct MessagePackSerializer { + writer: W, +} + +impl MessagePackSerializer +where + W: std::io::Write, +{ + #[inline] + pub fn new(writer: W) -> Self { + MessagePackSerializer { writer } + } +} + +pub struct Compound<'a, W> { + se: &'a mut MessagePackSerializer, +} + +impl SerializeSeq for Compound<'_, W> +where + W: std::io::Write, +{ + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.se) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl SerializeMap for Compound<'_, W> +where + W: std::io::Write, +{ + type Ok = (); + type Error = Error; + + fn serialize_key(&mut self, key: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + key.serialize(&mut *self.se) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.se) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W> Serializer for &'a mut MessagePackSerializer +where + W: std::io::Write, +{ + type Ok = (); + type Error = Error; + + type SerializeSeq = Compound<'a, W>; + type SerializeTuple = Impossible<(), Error>; + type SerializeTupleStruct = Impossible<(), Error>; + type SerializeTupleVariant = Impossible<(), Error>; + type SerializeMap = Compound<'a, W>; + type SerializeStruct = Impossible<(), Error>; + type SerializeStructVariant = Impossible<(), Error>; + + fn serialize_bool(self, value: bool) -> Result { + rmp::encode::write_bool(&mut self.writer, value).map_err(|_| Error::Write) + } + + fn serialize_i8(self, value: i8) -> Result { + self.serialize_i64(i64::from(value)) + } + + fn serialize_i16(self, value: i16) -> Result { + self.serialize_i64(i64::from(value)) + } + + fn serialize_i32(self, value: i32) -> Result { + self.serialize_i64(i64::from(value)) + } + + fn serialize_i64(self, value: i64) -> Result { + rmp::encode::write_sint(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_i128(self, value: i128) -> Result { + self.serialize_bytes(&value.to_be_bytes()) + } + + fn serialize_u8(self, value: u8) -> Result { + self.serialize_u64(u64::from(value)) + } + + fn serialize_u16(self, value: u16) -> Result { + self.serialize_u64(u64::from(value)) + } + + fn serialize_u32(self, value: u32) -> Result { + self.serialize_u64(u64::from(value)) + } + + fn serialize_u64(self, value: u64) -> Result { + rmp::encode::write_uint(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_u128(self, value: u128) -> Result { + self.serialize_bytes(&value.to_be_bytes()) + } + + fn serialize_f32(self, value: f32) -> Result { + rmp::encode::write_f32(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_f64(self, value: f64) -> Result { + rmp::encode::write_f64(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_char(self, _value: char) -> Result { + unreachable!(); + } + + fn serialize_str(self, value: &str) -> Result { + rmp::encode::write_str(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_bytes(self, value: &[u8]) -> Result { + rmp::encode::write_bin(&mut self.writer, value)?; + Ok(()) + } + + fn serialize_none(self) -> Result<(), Self::Error> { + self.serialize_unit() + } + + fn serialize_some(self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + rmp::encode::write_nil(&mut self.writer).map_err(|_| Error::Write) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unreachable!(); + } + + fn serialize_unit_variant( + self, + _name: &str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unreachable!(); + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + unreachable!(); + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + let tag: i8 = variant_index.try_into().unwrap_or_else(|_| unreachable!()); + let mut ext_se = ExtSerializer::new(tag, &mut self.writer); + value.serialize(&mut ext_se) + } + + fn serialize_seq(self, len: Option) -> Result { + match len { + Some(len) => { + rmp::encode::write_array_len(&mut self.writer, len as u32)?; + Ok(Compound { se: self }) + } + None => unreachable!(), + } + } + + fn serialize_tuple(self, _len: usize) -> Result { + unreachable!(); + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_map(self, len: Option) -> Result { + match len { + Some(len) => { + rmp::encode::write_map_len(&mut self.writer, len as u32)?; + Ok(Compound { se: self }) + } + None => unreachable!(), + } + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!(); + } +} + pub fn serialize( ptr: *mut pyo3::ffi::PyObject, default: Option>, @@ -29,7 +514,7 @@ pub fn serialize( ) -> Result, String> { let mut buf = BytesWriter::default(); let obj = PyObject::new(ptr, opts, 0, 0, default); - let mut ser = rmp_serde::Serializer::new(&mut buf); + let mut ser = MessagePackSerializer::new(&mut buf); let res = obj.serialize(&mut ser); match res { Ok(_) => Ok(buf.finish()),