From af30da645f05ef2d98a117fbb879f70c20909af3 Mon Sep 17 00:00:00 2001 From: David Edey Date: Wed, 25 Sep 2024 19:18:45 +0100 Subject: [PATCH 1/3] tweak: Use a visitor pattern for more efficient traversals And make VecTraverser use a shim to the new visitor pattern to validate it in tests --- radix-common/src/data/manifest/definitions.rs | 1 + radix-common/src/data/scrypto/definitions.rs | 1 + sbor/src/basic.rs | 1 + sbor/src/lib.rs | 7 + sbor/src/traversal/typed/typed_traverser.rs | 2 + .../untyped/event_stream_traverser.rs | 484 +++++++++++++ sbor/src/traversal/untyped/mod.rs | 29 +- .../src/traversal/untyped/traversal_traits.rs | 111 +++ .../{traverser.rs => untyped_traverser.rs} | 637 +++++++++--------- .../src/traversal/untyped/utility_visitors.rs | 80 +++ sbor/src/vec_traits.rs | 5 +- 11 files changed, 1047 insertions(+), 311 deletions(-) create mode 100644 sbor/src/traversal/untyped/event_stream_traverser.rs create mode 100644 sbor/src/traversal/untyped/traversal_traits.rs rename sbor/src/traversal/untyped/{traverser.rs => untyped_traverser.rs} (56%) create mode 100644 sbor/src/traversal/untyped/utility_visitors.rs diff --git a/radix-common/src/data/manifest/definitions.rs b/radix-common/src/data/manifest/definitions.rs index bf803b2bc24..b1e9b8adb5f 100644 --- a/radix-common/src/data/manifest/definitions.rs +++ b/radix-common/src/data/manifest/definitions.rs @@ -10,6 +10,7 @@ pub type ManifestDecoder<'a> = VecDecoder<'a, ManifestCustomValueKind>; pub type ManifestValueKind = ValueKind; pub type ManifestValue = Value; pub type ManifestEnumVariantValue = EnumVariantValue; +#[allow(deprecated)] pub type ManifestTraverser<'a> = VecTraverser<'a, ManifestCustomTraversal>; pub trait ManifestCategorize: Categorize {} diff --git a/radix-common/src/data/scrypto/definitions.rs b/radix-common/src/data/scrypto/definitions.rs index f79caae750b..25c5ffb23b0 100644 --- a/radix-common/src/data/scrypto/definitions.rs +++ b/radix-common/src/data/scrypto/definitions.rs @@ -7,6 
+7,7 @@ pub use crate::constants::SCRYPTO_SBOR_V1_PAYLOAD_PREFIX; pub type ScryptoEncoder<'a> = VecEncoder<'a, ScryptoCustomValueKind>; pub type ScryptoDecoder<'a> = VecDecoder<'a, ScryptoCustomValueKind>; +#[allow(deprecated)] pub type ScryptoTraverser<'a> = VecTraverser<'a, ScryptoCustomTraversal>; pub type ScryptoValueKind = ValueKind; pub type ScryptoValue = Value; diff --git a/sbor/src/basic.rs b/sbor/src/basic.rs index 27e50a52bf8..1deef0d703d 100644 --- a/sbor/src/basic.rs +++ b/sbor/src/basic.rs @@ -27,6 +27,7 @@ impl CustomValue for NoCustomValue { pub type BasicEncoder<'a> = VecEncoder<'a, NoCustomValueKind>; pub type BasicDecoder<'a> = VecDecoder<'a, NoCustomValueKind>; +#[allow(deprecated)] pub type BasicTraverser<'a> = VecTraverser<'a, NoCustomTraversal>; pub type BasicValue = Value; pub type BasicValueKind = ValueKind; diff --git a/sbor/src/lib.rs b/sbor/src/lib.rs index f2bd9fab953..5f16fae30a8 100644 --- a/sbor/src/lib.rs +++ b/sbor/src/lib.rs @@ -122,5 +122,12 @@ pub(crate) mod internal_prelude { pub use crate::prelude::*; // These are mostly used for more advanced use cases, // so aren't included in the general prelude + pub use crate::basic::*; + pub use crate::decoder::*; + pub use crate::encoder::*; + pub use crate::payload_validation::*; + pub use crate::schema::*; + pub use crate::traversal::*; pub use crate::vec_traits::*; + pub use core::ops::ControlFlow; } diff --git a/sbor/src/traversal/typed/typed_traverser.rs b/sbor/src/traversal/typed/typed_traverser.rs index ed61d4eb938..43082cbe2b6 100644 --- a/sbor/src/traversal/typed/typed_traverser.rs +++ b/sbor/src/traversal/typed/typed_traverser.rs @@ -44,6 +44,7 @@ pub fn traverse_partial_payload_with_types<'de, 's, E: CustomExtension>( /// It validates that the payload matches the given type kinds, /// and adds the relevant type index to the events which are output. 
pub struct TypedTraverser<'de, 's, E: CustomExtension> { + #[allow(deprecated)] traverser: VecTraverser<'de, E::CustomTraversal>, state: TypedTraverserState<'s, E>, } @@ -114,6 +115,7 @@ macro_rules! look_up_type { }; } +#[allow(deprecated)] // Allow use of deprecated VecTraverser impl<'de, 's, E: CustomExtension> TypedTraverser<'de, 's, E> { pub fn new( input: &'de [u8], diff --git a/sbor/src/traversal/untyped/event_stream_traverser.rs b/sbor/src/traversal/untyped/event_stream_traverser.rs new file mode 100644 index 00000000000..8265ea836a7 --- /dev/null +++ b/sbor/src/traversal/untyped/event_stream_traverser.rs @@ -0,0 +1,484 @@ +use crate::internal_prelude::*; + +// ================= +// DEPRECATION NOTES +// ================= +// Once we no longer need this (because we've moved to the visitor model), this opens up +// a world of further optimisations: +// * We can change it so that issuing a ControlFlow::Break aborts the process +// * Can get rid of `NextAction` and Step completely +// * And instead, just have standard top-down traversal logic +// * We can avoid storing a `resume_action` on the events + +/// The `VecTraverser` is for streamed decoding of a payload or single encoded value (tree). +/// It turns payload decoding into a pull-based event stream. +/// +/// The caller is responsible for stopping calling `next_event` after an Error or End event. 
+#[deprecated = "Use UntypedTraverser which uses the visitor pattern and is more efficient"] +pub struct VecTraverser<'de, T: CustomTraversal> { + untyped_traverser: UntypedTraverser<'de, T>, + visitor: EventStreamVisitor<'de, T>, +} + +pub struct VecTraverserConfig { + pub max_depth: usize, + pub check_exact_end: bool, +} + +#[allow(deprecated)] +impl<'de, T: CustomTraversal> VecTraverser<'de, T> { + pub fn new( + input: &'de [u8], + expected_start: ExpectedStart, + config: VecTraverserConfig, + ) -> Self { + let config = UntypedTraverserConfig { + max_depth: config.max_depth, + check_exact_end: config.check_exact_end, + }; + let untyped_traverser = UntypedTraverser::::new(input, config); + Self { + untyped_traverser, + visitor: EventStreamVisitor { + next_action: SuspendableNextAction::Action(expected_start.into_starting_action()), + next_event: None, + }, + } + } + + pub fn next_event<'t>(&'t mut self) -> LocatedTraversalEvent<'t, 'de, T> { + match self.visitor.next_action { + SuspendableNextAction::Action(next_action) => { + let location = self + .untyped_traverser + .continue_traversal_from(next_action, &mut self.visitor); + LocatedTraversalEvent { + location, + event: self + .visitor + .next_event + .take() + .expect("Visitor always expected to populate an event"), + } + } + SuspendableNextAction::Errored => panic!("Can't get next event as already errored"), + SuspendableNextAction::Ended => todo!("Can't get next event as already ended"), + } + } +} + +#[cfg(test)] +mod tests { + use crate::internal_prelude::*; + + #[derive(Categorize, Encode)] + #[allow(dead_code)] + struct TestStruct { + x: u32, + } + + #[derive(Categorize, Encode)] + #[allow(dead_code)] + enum TestEnum { + A { x: u32 }, + B(u32), + C, + } + + #[test] + pub fn test_calculate_value_tree_body_byte_array() { + let payload = basic_encode(&BasicValue::Array { + element_value_kind: BasicValueKind::Array, + elements: vec![BasicValue::Array { + element_value_kind: BasicValueKind::U8, + elements: 
vec![BasicValue::U8 { value: 44 }, BasicValue::U8 { value: 55 }], + }], + }) + .unwrap(); + /* + 91 - prefix + 32 - value kind: array + 32 - element value kind: array + 1 - number of elements: 1 + 7 - element value kind: u8 + 2 - number of elements: 2 + 44 - u8 + 55 - u8 + */ + let length = calculate_value_tree_body_byte_length::( + &payload[2..], + BasicValueKind::Array, + 0, + 100, + ) + .unwrap(); + assert_eq!(length, 6); + let length = calculate_value_tree_body_byte_length::( + &payload[6..], + BasicValueKind::U8, + 0, + 100, + ) + .unwrap(); + assert_eq!(length, 1); + } + + #[test] + pub fn test_exact_events_returned() { + let payload = basic_encode(&( + 2u8, + vec![3u8, 7u8], + (3u32, indexmap!(16u8 => 18u32)), + TestEnum::B(4u32), + Vec::::new(), + Vec::::new(), + vec![vec![(-2i64,)]], + )) + .unwrap(); + + let mut traverser = basic_payload_traverser(&payload); + + // Start: + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 7 }), + 1, + 1, + 3, + ); + // First line + next_event_is_terminal_value(&mut traverser, TerminalValueRef::U8(2), 2, 3, 5); + // Second line + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::U8, + length: 2, + }), + 2, + 5, + 8, + ); + next_event_is_terminal_value_slice( + &mut traverser, + TerminalValueBatchRef::U8(&[3u8, 7u8]), + 3, + 8, + 10, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::U8, + length: 2, + }), + 2, + 5, + 10, + ); + // Third line + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 2 }), + 2, + 10, + 12, + ); + next_event_is_terminal_value(&mut traverser, TerminalValueRef::U32(3), 3, 12, 17); + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Map(MapHeader { + key_value_kind: ValueKind::U8, + value_value_kind: ValueKind::U32, + 
length: 1, + }), + 3, + 17, + 21, + ); + next_event_is_terminal_value(&mut traverser, TerminalValueRef::U8(16), 4, 21, 22); + next_event_is_terminal_value(&mut traverser, TerminalValueRef::U32(18), 4, 22, 26); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Map(MapHeader { + key_value_kind: ValueKind::U8, + value_value_kind: ValueKind::U32, + length: 1, + }), + 3, + 17, + 26, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 2 }), + 2, + 10, + 26, + ); + // Fourth line + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::EnumVariant(EnumVariantHeader { + variant: 1, + length: 1, + }), + 2, + 26, + 29, + ); + next_event_is_terminal_value(&mut traverser, TerminalValueRef::U32(4), 3, 29, 34); + next_event_is_container_end( + &mut traverser, + ContainerHeader::EnumVariant(EnumVariantHeader { + variant: 1, + length: 1, + }), + 2, + 26, + 34, + ); + // Fifth line - empty Vec - no bytes event is output + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::U8, + length: 0, + }), + 2, + 34, + 37, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::U8, + length: 0, + }), + 2, + 34, + 37, + ); + // Sixth line - empty Vec + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::I32, + length: 0, + }), + 2, + 37, + 40, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::I32, + length: 0, + }), + 2, + 37, + 40, + ); + // Seventh line - Vec> + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::Array, + length: 1, + }), + 2, + 40, + 43, + ); + next_event_is_container_start_header( + &mut traverser, + 
ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::Tuple, + length: 1, + }), + 3, + 43, + 45, + ); + next_event_is_container_start_header( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 1 }), + 4, + 45, + 46, + ); + next_event_is_terminal_value(&mut traverser, TerminalValueRef::I64(-2), 5, 46, 55); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 1 }), + 4, + 45, + 55, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::Tuple, + length: 1, + }), + 3, + 43, + 55, + ); + next_event_is_container_end( + &mut traverser, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::Array, + length: 1, + }), + 2, + 40, + 55, + ); + + // End + next_event_is_container_end( + &mut traverser, + ContainerHeader::Tuple(TupleHeader { length: 7 }), + 1, + 1, + 55, + ); + next_event_is_end(&mut traverser, 55, 55); + } + + pub fn next_event_is_container_start_header( + traverser: &mut BasicTraverser, + expected_header: ContainerHeader, + expected_depth: usize, + expected_start_offset: usize, + expected_end_offset: usize, + ) { + let event = traverser.next_event(); + let sbor_depth = event.location.ancestor_path.len() + 1; + let LocatedTraversalEvent { + event: TraversalEvent::ContainerStart(header), + location: + Location { + start_offset, + end_offset, + .. 
+ }, + } = event + else { + panic!("Invalid event - expected ContainerStart, was {:?}", event); + }; + assert_eq!(header, expected_header); + assert_eq!(sbor_depth, expected_depth); + assert_eq!(start_offset, expected_start_offset); + assert_eq!(end_offset, expected_end_offset); + } + + pub fn next_event_is_container_end( + traverser: &mut BasicTraverser, + expected_header: ContainerHeader, + expected_depth: usize, + expected_start_offset: usize, + expected_end_offset: usize, + ) { + let event = traverser.next_event(); + let sbor_depth = event.location.ancestor_path.len() + 1; + let LocatedTraversalEvent { + event: TraversalEvent::ContainerEnd(header), + location: + Location { + start_offset, + end_offset, + .. + }, + } = event + else { + panic!("Invalid event - expected ContainerEnd, was {:?}", event); + }; + assert_eq!(header, expected_header); + assert_eq!(sbor_depth, expected_depth); + assert_eq!(start_offset, expected_start_offset); + assert_eq!(end_offset, expected_end_offset); + } + + pub fn next_event_is_terminal_value<'de>( + traverser: &mut BasicTraverser<'de>, + expected_value: TerminalValueRef<'de, NoCustomTraversal>, + expected_child_depth: usize, + expected_start_offset: usize, + expected_end_offset: usize, + ) { + let event = traverser.next_event(); + let sbor_depth = event.location.ancestor_path.len() + 1; + let LocatedTraversalEvent { + event: TraversalEvent::TerminalValue(value), + location: + Location { + start_offset, + end_offset, + .. 
+ }, + } = event + else { + panic!("Invalid event - expected TerminalValue, was {:?}", event); + }; + assert_eq!(value, expected_value); + assert_eq!(sbor_depth, expected_child_depth); + assert_eq!(start_offset, expected_start_offset); + assert_eq!(end_offset, expected_end_offset); + } + + pub fn next_event_is_terminal_value_slice<'de>( + traverser: &mut BasicTraverser<'de>, + expected_value_batch: TerminalValueBatchRef<'de>, + expected_child_depth: usize, + expected_start_offset: usize, + expected_end_offset: usize, + ) { + let event = traverser.next_event(); + let sbor_depth = event.location.ancestor_path.len() + 1; + let LocatedTraversalEvent { + event: TraversalEvent::TerminalValueBatch(value_batch), + location: + Location { + start_offset, + end_offset, + .. + }, + } = event + else { + panic!( + "Invalid event - expected TerminalValueBatch, was {:?}", + event + ); + }; + assert_eq!(value_batch, expected_value_batch); + assert_eq!(sbor_depth, expected_child_depth); + assert_eq!(start_offset, expected_start_offset); + assert_eq!(end_offset, expected_end_offset); + } + + pub fn next_event_is_end( + traverser: &mut BasicTraverser, + expected_start_offset: usize, + expected_end_offset: usize, + ) { + let event = traverser.next_event(); + let LocatedTraversalEvent { + event: TraversalEvent::End, + location: + Location { + start_offset, + end_offset, + .. 
+ }, + } = event + else { + panic!("Invalid event - expected End, was {:?}", event); + }; + assert_eq!(start_offset, expected_start_offset); + assert_eq!(end_offset, expected_end_offset); + assert!(event.location.ancestor_path.is_empty()); + } +} diff --git a/sbor/src/traversal/untyped/mod.rs b/sbor/src/traversal/untyped/mod.rs index 6acbc00de4b..28610ec0a95 100644 --- a/sbor/src/traversal/untyped/mod.rs +++ b/sbor/src/traversal/untyped/mod.rs @@ -1,5 +1,30 @@ +use crate::internal_prelude::*; + +mod event_stream_traverser; mod events; -mod traverser; +mod traversal_traits; +mod untyped_traverser; +mod utility_visitors; +pub use event_stream_traverser::*; pub use events::*; -pub use traverser::*; +pub use traversal_traits::*; +pub use untyped_traverser::*; +pub use utility_visitors::*; + +/// Returns the length of the value at the start of the partial payload. +pub fn calculate_value_tree_body_byte_length<'de, 's, E: CustomExtension>( + partial_payload: &'de [u8], + value_kind: ValueKind, + current_depth: usize, + depth_limit: usize, +) -> Result { + let mut traverser = UntypedTraverser::::new( + partial_payload, + UntypedTraverserConfig { + max_depth: depth_limit - current_depth, + check_exact_end: false, + }, + ); + traverser.run_from_start(ExpectedStart::ValueBody(value_kind), &mut ValidatingVisitor) +} diff --git a/sbor/src/traversal/untyped/traversal_traits.rs b/sbor/src/traversal/untyped/traversal_traits.rs new file mode 100644 index 00000000000..b51a2c51d12 --- /dev/null +++ b/sbor/src/traversal/untyped/traversal_traits.rs @@ -0,0 +1,111 @@ +use crate::internal_prelude::*; +use core::ops::ControlFlow; + +pub trait CustomTraversal: Copy + Debug + Clone + PartialEq + Eq + 'static { + type CustomValueKind: CustomValueKind; + type CustomTerminalValueRef<'de>: CustomTerminalValueRef< + CustomValueKind = Self::CustomValueKind, + >; + + fn read_custom_value_body<'de, R>( + custom_value_kind: Self::CustomValueKind, + reader: &mut R, + ) -> Result, DecodeError> + 
where + R: BorrowingDecoder<'de, Self::CustomValueKind>; +} + +pub trait CustomTerminalValueRef: Debug + Clone + PartialEq + Eq { + type CustomValueKind: CustomValueKind; + + fn custom_value_kind(&self) -> Self::CustomValueKind; +} + +// We add this allow so that the placeholder names don't have to start with underscores +#[allow(unused_variables)] +pub trait UntypedPayloadVisitor<'de, T: CustomTraversal> { + type Output<'t>; + + #[inline] + #[must_use] + fn on_container_start<'t>( + &mut self, + details: OnContainerStart<'t, T>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_terminal_value<'t>( + &mut self, + details: OnTerminalValue<'t, 'de, T>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_terminal_value_batch<'t>( + &mut self, + details: OnTerminalValueBatch<'t, 'de, T>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_container_end<'t>( + &mut self, + details: OnContainerEnd<'t, T>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[must_use] + fn on_error<'t>(&mut self, details: OnError<'t, T>) -> Self::Output<'t>; + + #[must_use] + fn on_traversal_end<'t>(&mut self, details: OnTraversalEnd<'t, T>) -> Self::Output<'t>; +} + +pub struct OnContainerStart<'t, T: CustomTraversal> { + pub header: ContainerHeader, + pub location: Location<'t, T>, + /// If requesting to break, the traversal can be continued with this action. + /// This will be optimized out if the visitor doesn't use it. + pub resume_action: NextAction, +} + +pub struct OnTerminalValue<'t, 'de, T: CustomTraversal> { + pub value: TerminalValueRef<'de, T>, + pub location: Location<'t, T>, + /// If requesting to break, the traversal can be continued with this action. + /// This will be optimized out if the visitor doesn't use it. 
+ pub resume_action: NextAction, +} + +pub struct OnTerminalValueBatch<'t, 'de, T: CustomTraversal> { + pub value_batch: TerminalValueBatchRef<'de>, + pub location: Location<'t, T>, + /// If requesting to break, the traversal can be continued with this action. + /// This will be optimized out if the visitor doesn't require it. + pub resume_action: NextAction, +} + +pub struct OnContainerEnd<'t, T: CustomTraversal> { + pub header: ContainerHeader, + pub location: Location<'t, T>, + /// If requesting to break, the traversal can be continued with this action. + /// This will be optimized out if the visitor doesn't require it. + pub resume_action: NextAction, +} + +pub struct OnTraversalEnd<'t, T: CustomTraversal> { + pub location: Location<'t, T>, +} + +pub struct OnError<'t, T: CustomTraversal> { + pub error: DecodeError, + pub location: Location<'t, T>, +} diff --git a/sbor/src/traversal/untyped/traverser.rs b/sbor/src/traversal/untyped/untyped_traverser.rs similarity index 56% rename from sbor/src/traversal/untyped/traverser.rs rename to sbor/src/traversal/untyped/untyped_traverser.rs index 08b60027b15..85330f2a747 100644 --- a/sbor/src/traversal/untyped/traverser.rs +++ b/sbor/src/traversal/untyped/untyped_traverser.rs @@ -1,3 +1,5 @@ +use core::ops::ControlFlow; + use super::*; use crate::decoder::BorrowingDecoder; use crate::rust::prelude::*; @@ -5,49 +7,11 @@ use crate::rust::str; use crate::value_kind::*; use crate::*; -/// Returns the length of the value at the start of the partial payload. 
-pub fn calculate_value_tree_body_byte_length<'de, 's, E: CustomExtension>( - partial_payload: &'de [u8], - value_kind: ValueKind, - current_depth: usize, - depth_limit: usize, -) -> Result { - let mut traverser = VecTraverser::::new( - partial_payload, - ExpectedStart::ValueBody(value_kind), - VecTraverserConfig { - max_depth: depth_limit - current_depth, - check_exact_end: false, - }, - ); - loop { - let next_event = traverser.next_event(); - match next_event.event { - TraversalEvent::End => return Ok(next_event.location.end_offset), - TraversalEvent::DecodeError(decode_error) => return Err(decode_error), - _ => {} - } - } -} - -pub trait CustomTraversal: Copy + Debug + Clone + PartialEq + Eq { - type CustomValueKind: CustomValueKind; - type CustomTerminalValueRef<'de>: CustomTerminalValueRef< - CustomValueKind = Self::CustomValueKind, - >; - - fn read_custom_value_body<'de, R>( - custom_value_kind: Self::CustomValueKind, - reader: &mut R, - ) -> Result, DecodeError> - where - R: BorrowingDecoder<'de, Self::CustomValueKind>; -} - -pub trait CustomTerminalValueRef: Debug + Clone + PartialEq + Eq { - type CustomValueKind: CustomValueKind; - - fn custom_value_kind(&self) -> Self::CustomValueKind; +/// Designed for streamed decoding of a payload or single encoded value (tree). +pub struct UntypedTraverser<'de, T: CustomTraversal> { + decoder: VecDecoder<'de, T::CustomValueKind>, + ancestor_path: Vec>, + config: UntypedTraverserConfig, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -65,24 +29,15 @@ pub struct AncestorState { impl AncestorState { #[inline] - fn get_implicit_value_kind_of_current_child(&self) -> Option> { + pub fn get_implicit_value_kind_of_current_child( + &self, + ) -> Option> { self.container_header .get_implicit_child_value_kind(self.current_child_index) } } -/// The `VecTraverser` is for streamed decoding of a payload or single encoded value (tree). -/// It turns payload decoding into a pull-based event stream. 
-/// -/// The caller is responsible for stopping calling `next_event` after an Error or End event. -pub struct VecTraverser<'de, T: CustomTraversal> { - decoder: VecDecoder<'de, T::CustomValueKind>, - ancestor_path: Vec>, - next_action: NextAction, - config: VecTraverserConfig, -} - -pub struct VecTraverserConfig { +pub struct UntypedTraverserConfig { pub max_depth: usize, pub check_exact_end: bool, } @@ -103,10 +58,6 @@ pub enum NextAction { /// The state which is put into after entering parent, and /// the default state to return to from below ReadNextChildOrExitContainer, - Errored, - Ended, - /// Impossible to observe this value - InProgressPlaceholder, } #[derive(Debug, Clone, Copy)] @@ -116,73 +67,89 @@ pub enum ExpectedStart { ValueBody(ValueKind), } -impl<'de, T: CustomTraversal> VecTraverser<'de, T> { - pub fn new( - input: &'de [u8], - expected_start: ExpectedStart, - config: VecTraverserConfig, - ) -> Self { +impl ExpectedStart { + pub fn into_starting_action>(self) -> NextAction { + match self { + ExpectedStart::PayloadPrefix(prefix) => NextAction::ReadPrefix { + expected_prefix: prefix, + }, + ExpectedStart::Value => NextAction::ReadRootValue, + ExpectedStart::ValueBody(value_kind) => NextAction::ReadRootValueBody { + implicit_value_kind: value_kind, + }, + } + } +} + +impl<'de, T: CustomTraversal> UntypedTraverser<'de, T> { + pub fn new(input: &'de [u8], config: UntypedTraverserConfig) -> Self { Self { - // Note that the VecTraverser needs to be very low level for performance, + // Note that the UntypedTraverser needs to be very low level for performance, // so purposefully doesn't use the depth tracking in the decoder itself. // But we set a max depth anyway, for safety. 
decoder: VecDecoder::new(input, config.max_depth), ancestor_path: Vec::with_capacity(config.max_depth), - next_action: match expected_start { - ExpectedStart::PayloadPrefix(prefix) => NextAction::ReadPrefix { - expected_prefix: prefix, - }, - ExpectedStart::Value => NextAction::ReadRootValue, - ExpectedStart::ValueBody(value_kind) => NextAction::ReadRootValueBody { - implicit_value_kind: value_kind, - }, - }, config, } } - pub fn next_event<'t>(&'t mut self) -> LocatedTraversalEvent<'t, 'de, T> { - let (event, next_action) = Self::step( - core::mem::replace(&mut self.next_action, NextAction::InProgressPlaceholder), - &self.config, - &mut self.decoder, - &mut self.ancestor_path, - ); - self.next_action = next_action; - event + pub fn run_from_start<'t, V: UntypedPayloadVisitor<'de, T>>( + &'t mut self, + expected_start: ExpectedStart, + visitor: &mut V, + ) -> V::Output<'t> { + self.continue_traversal_from(expected_start.into_starting_action(), visitor) + } + + /// # Expected behaviour + /// Start action should either be an action from ExpectedStart, or a `resume_action` returned + /// in a previous event. 
+ pub fn continue_traversal_from<'t, V: UntypedPayloadVisitor<'de, T>>( + &'t mut self, + start_action: NextAction, + visitor: &mut V, + ) -> V::Output<'t> { + let mut action = start_action; + loop { + // SAFETY: Work around the current borrow checker, which is sound as per this thread: + // https://users.rust-lang.org/t/mutable-borrow-in-loop-borrow-checker-query/118081/3 + // Unsafe syntax borrowed from here: https://docs.rs/polonius-the-crab/latest/polonius_the_crab/ + // Can remove this once the polonius borrow checker hits stable + let ancester_path = unsafe { &mut *(&mut self.ancestor_path as *mut _) }; + action = match Self::step( + action, + &self.config, + &mut self.decoder, + ancester_path, + visitor, + ) { + ControlFlow::Continue(action) => action, + ControlFlow::Break(output) => return output, + }; + } } #[inline] - fn step<'t, 'd>( + fn step<'t, V: UntypedPayloadVisitor<'de, T>>( action: NextAction, - config: &VecTraverserConfig, - decoder: &'d mut VecDecoder<'de, T::CustomValueKind>, + config: &UntypedTraverserConfig, + decoder: &mut VecDecoder<'de, T::CustomValueKind>, ancestor_path: &'t mut Vec>, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { + visitor: &mut V, + ) -> ControlFlow, NextAction> { match action { NextAction::ReadPrefix { expected_prefix } => { - // The reading of the prefix has no associated event, so we perform the prefix check first, - // and then proceed to read the root value if it succeeds. - let start_offset = decoder.get_offset(); - match decoder.read_and_check_payload_prefix(expected_prefix) { - Ok(()) => { - // Prefix read successfully. Now read root value. 
- ActionHandler::new_from_current_offset(ancestor_path, decoder) - .read_value(None) - } - Err(error) => { - ActionHandler::new_with_fixed_offset(ancestor_path, decoder, start_offset) - .complete_with_error(error) - } - } + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_and_check_payload_prefix::(expected_prefix, visitor) } NextAction::ReadRootValue => { - ActionHandler::new_from_current_offset(ancestor_path, decoder).read_value(None) + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_value(None, visitor) } NextAction::ReadRootValueBody { implicit_value_kind, - } => ActionHandler::new_from_current_offset(ancestor_path, decoder) - .read_value(Some(implicit_value_kind)), + } => Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_value(Some(implicit_value_kind), visitor), NextAction::ReadContainerContentStart { container_header, container_start_offset, @@ -191,53 +158,55 @@ impl<'de, T: CustomTraversal> VecTraverser<'de, T> { if container_child_size == 0 { // If the container has no children, we immediately container end without ever bothering // adding it as an ancestor. - return ActionHandler::new_with_fixed_offset( - ancestor_path, - decoder, + Locator::with(container_start_offset, ancestor_path, decoder) + .complete_container_end(container_header, visitor) + } else { + // Add ancestor before checking for max depth so that the ancestor stack is + // correct if the depth check returns an error + ancestor_path.push(AncestorState { + container_header, container_start_offset, - ) - .complete_container_end(container_header); - } - - // Add ancestor before checking for max depth so that the ancestor stack is - // correct if the depth check returns an error - ancestor_path.push(AncestorState { - container_header, - container_start_offset, - current_child_index: 0, - }); - // We know we're about to read a child at depth ancestor_path.len() + 1 - so - // it's an error if ancestor_path.len() >= config.max_depth. 
- // (We avoid the +1 so that we don't need to worry about overflow). - if ancestor_path.len() >= config.max_depth { - return ActionHandler::new_from_current_offset(ancestor_path, decoder) - .complete_with_error(DecodeError::MaxDepthExceeded(config.max_depth)); - } + current_child_index: 0, + }); + + // We know we're about to read a child at depth ancestor_path.len() + 1 - so + // it's an error if ancestor_path.len() >= config.max_depth. + // (We avoid the +1 so that we don't need to worry about overflow). + if ancestor_path.len() >= config.max_depth { + let error_output = + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .handle_error( + DecodeError::MaxDepthExceeded(config.max_depth), + visitor, + ); + return ControlFlow::Break(error_output); + } - let parent = ancestor_path.last_mut().unwrap(); - let parent_container = &parent.container_header; - let is_byte_array = matches!( - parent_container, - ContainerHeader::Array(ArrayHeader { - element_value_kind: ValueKind::U8, - .. - }) - ); - // If it's a byte array, we do a batch-read optimisation - if is_byte_array { - // We know this is >= 1 from the above check - let array_length = container_child_size; - let max_index_which_would_be_read = array_length - 1; - // Set current child index before we read so that if we get an error on read - // then it comes through at the max child index we attempted to read. 
- parent.current_child_index = max_index_which_would_be_read; - ActionHandler::new_from_current_offset(ancestor_path, decoder) - .read_byte_array(array_length) - } else { - // NOTE: parent.current_child_index is already 0, so no need to change it - let implicit_value_kind = parent.get_implicit_value_kind_of_current_child(); - ActionHandler::new_from_current_offset(ancestor_path, decoder) - .read_value(implicit_value_kind) + let parent = ancestor_path.last_mut().unwrap(); + let parent_container = &parent.container_header; + let is_byte_array = matches!( + parent_container, + ContainerHeader::Array(ArrayHeader { + element_value_kind: ValueKind::U8, + .. + }) + ); + // If it's a byte array, we do a batch-read optimisation + if is_byte_array { + // We know this is >= 1 from the above check + let array_length = container_child_size; + let max_index_which_would_be_read = array_length - 1; + // Set current child index before we read so that if we get an error on read + // then it comes through at the max child index we attempted to read. + parent.current_child_index = max_index_which_would_be_read; + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_byte_array(array_length, visitor) + } else { + // NOTE: parent.current_child_index is already 0, so no need to change it + let implicit_value_kind = parent.get_implicit_value_kind_of_current_child(); + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_value(implicit_value_kind, visitor) + } } } NextAction::ReadNextChildOrExitContainer => { @@ -255,68 +224,52 @@ impl<'de, T: CustomTraversal> VecTraverser<'de, T> { .. 
} = ancestor_path.pop().expect("Parent has just been read"); - ActionHandler::new_with_fixed_offset( - ancestor_path, - decoder, - container_start_offset, - ) - .complete_container_end(container_header) + Locator::with(container_start_offset, ancestor_path, decoder) + .complete_container_end(container_header, visitor) } else { parent.current_child_index = next_child_index; let implicit_value_kind = parent.get_implicit_value_kind_of_current_child(); - ActionHandler::new_from_current_offset(ancestor_path, decoder) - .read_value(implicit_value_kind) + Locator::with(decoder.get_offset(), ancestor_path, decoder) + .read_value(implicit_value_kind, visitor) } } None => { // We are due to read another element and exit but have no parent // This is because we have finished reading the `root` value. - // Therefore we call `end`. - ActionHandler::new_from_current_offset(ancestor_path, decoder).end(config) + let output = Locator::with(decoder.get_offset(), ancestor_path, decoder) + .handle_traversal_end(config.check_exact_end, visitor); + return ControlFlow::Break(output); } } } - NextAction::Errored => { - panic!("It is unsupported to call `next_event` on a traverser which has returned an error.") - } - NextAction::Ended => { - panic!("It is unsupported to call `next_event` on a traverser which has already emitted an end event.") - } - NextAction::InProgressPlaceholder => { - unreachable!("It is not possible to observe this value - it is a placeholder for rust memory safety.") - } } } } -macro_rules! handle_error { - ($action_handler: expr, $result: expr$(,)?) => {{ - match $result { - Ok(value) => value, - Err(error) => { - return $action_handler.complete_with_error(error); - } - } +macro_rules! handle_result { + ($self: expr, $visitor: expr, $result: expr$(,)?) => {{ + let result = $result; + $self.handle_result(result, $visitor)? 
}}; } /// This is just an encapsulation to improve code quality by: /// * Removing code duplication by capturing the ancestor_path/decoder/start_offset in one place /// * Ensuring code correctness by fixing the ancestor path -struct ActionHandler<'t, 'd, 'de, T: CustomTraversal> { +struct Locator<'t, 'd, 'de, T: CustomTraversal> { ancestor_path: &'t [AncestorState], decoder: &'d mut VecDecoder<'de, T::CustomValueKind>, start_offset: usize, } -impl<'t, 'd, 'de, T: CustomTraversal> ActionHandler<'t, 'd, 'de, T> { +impl<'t, 'd, 'de, T: CustomTraversal> Locator<'t, 'd, 'de, T> { #[inline] - fn new_from_current_offset( + fn with( + start_offset: usize, ancestor_path: &'t [AncestorState], decoder: &'d mut VecDecoder<'de, T::CustomValueKind>, ) -> Self { - let start_offset = decoder.get_offset(); Self { ancestor_path, decoder, @@ -324,192 +277,266 @@ impl<'t, 'd, 'de, T: CustomTraversal> ActionHandler<'t, 'd, 'de, T> { } } + #[must_use] #[inline] - fn new_with_fixed_offset( - ancestor_path: &'t [AncestorState], - decoder: &'d mut VecDecoder<'de, T::CustomValueKind>, - start_offset: usize, - ) -> Self { - Self { - ancestor_path, - decoder, - start_offset, - } + fn read_and_check_payload_prefix>( + self, + expected_prefix: u8, + visitor: &mut V, + ) -> ControlFlow, NextAction> { + handle_result!( + self, + visitor, + self.decoder.read_and_check_payload_prefix(expected_prefix) + ); + ControlFlow::Continue(NextAction::ReadRootValue) } #[inline] - fn read_value( + #[must_use] + fn read_value>( self, implicit_value_kind: Option>, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { + visitor: &mut V, + ) -> ControlFlow, NextAction> { let value_kind = match implicit_value_kind { Some(value_kind) => value_kind, - None => handle_error!(self, self.decoder.read_value_kind()), + None => handle_result!(self, visitor, self.decoder.read_value_kind()), }; - self.read_value_body(value_kind) + self.read_value_body(value_kind, visitor) } #[inline] - fn read_byte_array( + #[must_use] + 
fn read_byte_array>( self, array_length: usize, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - let bytes = handle_error!(self, self.decoder.read_slice_from_payload(array_length)); - self.complete( - TraversalEvent::TerminalValueBatch(TerminalValueBatchRef::U8(bytes)), - // This is the correct action to ensure we exit the container on the next step - NextAction::ReadNextChildOrExitContainer, - ) - } - - #[inline] - fn end( - self, - config: &VecTraverserConfig, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - if config.check_exact_end { - handle_error!(self, self.decoder.check_end()); - } - self.complete(TraversalEvent::End, NextAction::Ended) + visitor: &mut V, + ) -> ControlFlow, NextAction> { + let bytes = handle_result!( + self, + visitor, + self.decoder.read_slice_from_payload(array_length) + ); + self.complete_terminal_value_batch(TerminalValueBatchRef::U8(bytes), visitor) } #[inline] - fn read_value_body( + #[must_use] + fn read_value_body>( self, value_kind: ValueKind, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { + visitor: &mut V, + ) -> ControlFlow, NextAction> { match value_kind { - ValueKind::Bool => self.read_terminal_value(value_kind, TerminalValueRef::Bool), - ValueKind::I8 => self.read_terminal_value(value_kind, TerminalValueRef::I8), - ValueKind::I16 => self.read_terminal_value(value_kind, TerminalValueRef::I16), - ValueKind::I32 => self.read_terminal_value(value_kind, TerminalValueRef::I32), - ValueKind::I64 => self.read_terminal_value(value_kind, TerminalValueRef::I64), - ValueKind::I128 => self.read_terminal_value(value_kind, TerminalValueRef::I128), - ValueKind::U8 => self.read_terminal_value(value_kind, TerminalValueRef::U8), - ValueKind::U16 => self.read_terminal_value(value_kind, TerminalValueRef::U16), - ValueKind::U32 => self.read_terminal_value(value_kind, TerminalValueRef::U32), - ValueKind::U64 => self.read_terminal_value(value_kind, TerminalValueRef::U64), - ValueKind::U128 => 
self.read_terminal_value(value_kind, TerminalValueRef::U128), + ValueKind::Bool => self.read_basic_value(value_kind, TerminalValueRef::Bool, visitor), + ValueKind::I8 => self.read_basic_value(value_kind, TerminalValueRef::I8, visitor), + ValueKind::I16 => self.read_basic_value(value_kind, TerminalValueRef::I16, visitor), + ValueKind::I32 => self.read_basic_value(value_kind, TerminalValueRef::I32, visitor), + ValueKind::I64 => self.read_basic_value(value_kind, TerminalValueRef::I64, visitor), + ValueKind::I128 => self.read_basic_value(value_kind, TerminalValueRef::I128, visitor), + ValueKind::U8 => self.read_basic_value(value_kind, TerminalValueRef::U8, visitor), + ValueKind::U16 => self.read_basic_value(value_kind, TerminalValueRef::U16, visitor), + ValueKind::U32 => self.read_basic_value(value_kind, TerminalValueRef::U32, visitor), + ValueKind::U64 => self.read_basic_value(value_kind, TerminalValueRef::U64, visitor), + ValueKind::U128 => self.read_basic_value(value_kind, TerminalValueRef::U128, visitor), ValueKind::String => { - let length = handle_error!(self, self.decoder.read_size()); - let bytes = handle_error!(self, self.decoder.read_slice_from_payload(length)); - let string_body = handle_error!( - self, - str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8) - ); - self.complete( - TraversalEvent::TerminalValue(TerminalValueRef::String(string_body)), - NextAction::ReadNextChildOrExitContainer, - ) + let length = handle_result!(self, visitor, self.decoder.read_size()); + let bytes = + handle_result!(self, visitor, self.decoder.read_slice_from_payload(length)); + let string_decode_result = + str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8); + let string_body = handle_result!(self, visitor, string_decode_result); + self.complete_terminal_value(TerminalValueRef::String(string_body), visitor) } ValueKind::Array => { - let element_value_kind = handle_error!(self, self.decoder.read_value_kind()); - let length = handle_error!(self, 
self.decoder.read_size()); - self.complete_container_start(ContainerHeader::Array(ArrayHeader { - element_value_kind, - length, - })) + let element_value_kind = + handle_result!(self, visitor, self.decoder.read_value_kind()); + let length = handle_result!(self, visitor, self.decoder.read_size()); + self.complete_container_start( + ContainerHeader::Array(ArrayHeader { + element_value_kind, + length, + }), + visitor, + ) } ValueKind::Map => { - let key_value_kind = handle_error!(self, self.decoder.read_value_kind()); - let value_value_kind = handle_error!(self, self.decoder.read_value_kind()); - let length = handle_error!(self, self.decoder.read_size()); - self.complete_container_start(ContainerHeader::Map(MapHeader { - key_value_kind, - value_value_kind, - length, - })) + let key_value_kind = handle_result!(self, visitor, self.decoder.read_value_kind()); + let value_value_kind = + handle_result!(self, visitor, self.decoder.read_value_kind()); + let length = handle_result!(self, visitor, self.decoder.read_size()); + self.complete_container_start( + ContainerHeader::Map(MapHeader { + key_value_kind, + value_value_kind, + length, + }), + visitor, + ) } ValueKind::Enum => { - let variant = handle_error!(self, self.decoder.read_byte()); - let length = handle_error!(self, self.decoder.read_size()); - self.complete_container_start(ContainerHeader::EnumVariant(EnumVariantHeader { - variant, - length, - })) + let variant = handle_result!(self, visitor, self.decoder.read_byte()); + let length = handle_result!(self, visitor, self.decoder.read_size()); + self.complete_container_start( + ContainerHeader::EnumVariant(EnumVariantHeader { variant, length }), + visitor, + ) } ValueKind::Tuple => { - let length = handle_error!(self, self.decoder.read_size()); - self.complete_container_start(ContainerHeader::Tuple(TupleHeader { length })) + let length = handle_result!(self, visitor, self.decoder.read_size()); + self.complete_container_start( + ContainerHeader::Tuple(TupleHeader { 
length }), + visitor, + ) } ValueKind::Custom(custom_value_kind) => { - let custom_value_ref = handle_error!( + let custom_value_ref = handle_result!( self, + visitor, T::read_custom_value_body(custom_value_kind, self.decoder) ); - self.complete( - TraversalEvent::TerminalValue(TerminalValueRef::Custom(custom_value_ref)), - NextAction::ReadNextChildOrExitContainer, - ) + self.complete_terminal_value(TerminalValueRef::Custom(custom_value_ref), visitor) } } } #[inline] - fn read_terminal_value>>( + #[must_use] + fn read_basic_value< + X: Decode>, + V: UntypedPayloadVisitor<'de, T>, + >( self, value_kind: ValueKind, - value_ref_constructor: impl Fn(V) -> TerminalValueRef<'de, T>, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - match V::decode_body_with_value_kind(self.decoder, value_kind) { - Ok(value) => self.complete( - TraversalEvent::TerminalValue(value_ref_constructor(value)), - NextAction::ReadNextChildOrExitContainer, - ), - Err(error) => self.complete_with_error(error), - } + value_ref_constructor: impl Fn(X) -> TerminalValueRef<'de, T>, + visitor: &mut V, + ) -> ControlFlow, NextAction> { + let value = handle_result!( + self, + visitor, + X::decode_body_with_value_kind(self.decoder, value_kind) + ); + self.complete_terminal_value(value_ref_constructor(value), visitor) } #[inline] - fn complete_container_start( + #[must_use] + fn complete_terminal_value>( + self, + value_ref: TerminalValueRef<'de, T>, + visitor: &mut V, + ) -> ControlFlow, NextAction> { + let next_action = NextAction::ReadNextChildOrExitContainer; + visitor.on_terminal_value(OnTerminalValue { + location: self.location(), + value: value_ref, + resume_action: next_action, + })?; + ControlFlow::Continue(next_action) + } + + #[inline] + #[must_use] + fn complete_terminal_value_batch>( + self, + value_batch_ref: TerminalValueBatchRef<'de>, + visitor: &mut V, + ) -> ControlFlow, NextAction> { + let next_action = NextAction::ReadNextChildOrExitContainer; + 
visitor.on_terminal_value_batch(OnTerminalValueBatch { + location: self.location(), + value_batch: value_batch_ref, + resume_action: next_action, + })?; + ControlFlow::Continue(next_action) + } + + #[inline] + #[must_use] + fn complete_container_start>( self, container_header: ContainerHeader, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { + visitor: &mut V, + ) -> ControlFlow, NextAction> { let next_action = NextAction::ReadContainerContentStart { container_header: container_header.clone(), container_start_offset: self.start_offset, }; - self.complete( - TraversalEvent::ContainerStart(container_header), - next_action, - ) + visitor.on_container_start(OnContainerStart { + location: self.location(), + header: container_header, + resume_action: next_action, + })?; + ControlFlow::Continue(next_action) } #[inline] - fn complete_container_end( + #[must_use] + fn complete_container_end>( self, container_header: ContainerHeader, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - self.complete( - TraversalEvent::ContainerEnd(container_header), - // Continue interating the parent - NextAction::ReadNextChildOrExitContainer, - ) + visitor: &mut V, + ) -> ControlFlow, NextAction> { + let next_action = NextAction::ReadNextChildOrExitContainer; + visitor.on_container_end(OnContainerEnd { + location: self.location(), + header: container_header, + resume_action: next_action, + })?; + ControlFlow::Continue(next_action) + } + + #[inline] + #[must_use] + fn handle_result, X>( + &self, + result: Result, + visitor: &mut V, + ) -> ControlFlow, X> { + match result { + Ok(value) => ControlFlow::Continue(value), + Err(error) => ControlFlow::Break(self.handle_error(error, visitor)), + } } #[inline] - fn complete_with_error( + #[must_use] + fn handle_traversal_end>( self, + check_end: bool, + visitor: &mut V, + ) -> V::Output<'t> { + if check_end { + if let Err(error) = self.decoder.check_end() { + return self.handle_error(error, visitor); + } + } + 
visitor.on_traversal_end(OnTraversalEnd { + location: self.location(), + }) + } + + #[inline] + #[must_use] + fn handle_error>( + &self, error: DecodeError, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - self.complete(TraversalEvent::DecodeError(error), NextAction::Errored) + visitor: &mut V, + ) -> V::Output<'t> { + visitor.on_error(OnError { + error, + location: self.location(), + }) } #[inline] - fn complete( - self, - traversal_event: TraversalEvent<'de, T>, - next_action: NextAction, - ) -> (LocatedTraversalEvent<'t, 'de, T>, NextAction) { - let located_event = LocatedTraversalEvent { - event: traversal_event, - location: Location { - start_offset: self.start_offset, - end_offset: self.decoder.get_offset(), - ancestor_path: self.ancestor_path, - }, - }; - (located_event, next_action) + fn location(&self) -> Location<'t, T> { + Location { + start_offset: self.start_offset, + end_offset: self.decoder.get_offset(), + ancestor_path: self.ancestor_path, + } } } diff --git a/sbor/src/traversal/untyped/utility_visitors.rs b/sbor/src/traversal/untyped/utility_visitors.rs new file mode 100644 index 00000000000..4e104d4e61d --- /dev/null +++ b/sbor/src/traversal/untyped/utility_visitors.rs @@ -0,0 +1,80 @@ +use crate::internal_prelude::*; + +pub struct ValidatingVisitor; + +impl<'de, T: CustomTraversal + 'static> UntypedPayloadVisitor<'de, T> for ValidatingVisitor { + type Output<'t> = Result; + + fn on_error<'t>(&mut self, details: OnError<'t, T>) -> Self::Output<'t> { + Err(details.error) + } + + fn on_traversal_end<'t>(&mut self, details: OnTraversalEnd<'t, T>) -> Self::Output<'t> { + Ok(details.location.end_offset) + } +} + +pub struct EventStreamVisitor<'de, T: CustomTraversal> { + pub next_action: SuspendableNextAction, + pub next_event: Option>, +} + +pub enum SuspendableNextAction { + Action(NextAction), + Errored, + Ended, +} + +impl<'de, T: CustomTraversal + 'static> UntypedPayloadVisitor<'de, T> + for EventStreamVisitor<'de, T> +{ + type 
Output<'t> = Location<'t, T>; + + fn on_container_start<'t>( + &mut self, + details: OnContainerStart<'t, T>, + ) -> ControlFlow> { + self.next_action = SuspendableNextAction::Action(details.resume_action); + self.next_event = Some(TraversalEvent::ContainerStart(details.header)); + ControlFlow::Break(details.location) + } + + fn on_terminal_value<'t>( + &mut self, + details: OnTerminalValue<'t, 'de, T>, + ) -> ControlFlow> { + self.next_action = SuspendableNextAction::Action(details.resume_action); + self.next_event = Some(TraversalEvent::TerminalValue(details.value)); + ControlFlow::Break(details.location) + } + + fn on_terminal_value_batch<'t>( + &mut self, + details: OnTerminalValueBatch<'t, 'de, T>, + ) -> ControlFlow> { + self.next_action = SuspendableNextAction::Action(details.resume_action); + self.next_event = Some(TraversalEvent::TerminalValueBatch(details.value_batch)); + ControlFlow::Break(details.location) + } + + fn on_container_end<'t>( + &mut self, + details: OnContainerEnd<'t, T>, + ) -> ControlFlow> { + self.next_action = SuspendableNextAction::Action(details.resume_action); + self.next_event = Some(TraversalEvent::ContainerEnd(details.header)); + ControlFlow::Break(details.location) + } + + fn on_error<'t>(&mut self, details: OnError<'t, T>) -> Self::Output<'t> { + self.next_action = SuspendableNextAction::Errored; + self.next_event = Some(TraversalEvent::DecodeError(details.error)); + details.location + } + + fn on_traversal_end<'t>(&mut self, details: OnTraversalEnd<'t, T>) -> Self::Output<'t> { + self.next_action = SuspendableNextAction::Ended; + self.next_event = Some(TraversalEvent::End); + details.location + } +} diff --git a/sbor/src/vec_traits.rs b/sbor/src/vec_traits.rs index 47559947b58..8b1ca011acf 100644 --- a/sbor/src/vec_traits.rs +++ b/sbor/src/vec_traits.rs @@ -1,7 +1,4 @@ -use crate::{ - internal_prelude::*, validate_payload_against_schema, CustomExtension, CustomSchema, - Decoder as _, Describe, Encoder as _, 
ValidatableCustomExtension, VecDecoder, VecEncoder, -}; +use crate::internal_prelude::*; pub trait VecEncode: for<'a> Encode> {} impl Encode> + ?Sized> VecEncode for T {} From ae4c431fa56f6b9c438d590f786040272b48fa3f Mon Sep 17 00:00:00 2001 From: David Edey Date: Wed, 25 Sep 2024 22:56:53 +0100 Subject: [PATCH 2/3] perf: Use visitor for IndexedScryptoValue --- .../src/types/indexed_value.rs | 88 +++++++++++-------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/radix-engine-interface/src/types/indexed_value.rs b/radix-engine-interface/src/types/indexed_value.rs index 9fb0b83e80e..9048a20d858 100644 --- a/radix-engine-interface/src/types/indexed_value.rs +++ b/radix-engine-interface/src/types/indexed_value.rs @@ -17,52 +17,66 @@ pub struct IndexedScryptoValue { scrypto_value: RefCell>, } +#[derive(Default)] +struct OwnedAndReferenceAggregator { + references: Vec, + owned_nodes: Vec, +} + +impl<'de> UntypedPayloadVisitor<'de, ScryptoCustomTraversal> for OwnedAndReferenceAggregator { + type Output<'t> = Result<(), DecodeError>; + + fn on_terminal_value<'t>( + &mut self, + details: OnTerminalValue<'t, 'de, ScryptoCustomTraversal>, + ) -> core::ops::ControlFlow> { + if let traversal::TerminalValueRef::Custom(custom) = details.value { + match custom.0 { + ScryptoCustomValue::Reference(node_id) => { + self.references.push(node_id.0.into()); + } + ScryptoCustomValue::Own(node_id) => { + self.owned_nodes.push(node_id.0.into()); + } + ScryptoCustomValue::Decimal(_) + | ScryptoCustomValue::PreciseDecimal(_) + | ScryptoCustomValue::NonFungibleLocalId(_) => {} + } + } + core::ops::ControlFlow::Continue(()) + } + + fn on_error<'t>(&mut self, details: OnError<'t, ScryptoCustomTraversal>) -> Self::Output<'t> { + Err(details.error) + } + + fn on_traversal_end<'t>( + &mut self, + _details: OnTraversalEnd<'t, ScryptoCustomTraversal>, + ) -> Self::Output<'t> { + Ok(()) + } +} + impl IndexedScryptoValue { fn new(bytes: Vec) -> Result { - let mut traverser = 
ScryptoTraverser::new( + let mut aggregates = OwnedAndReferenceAggregator::default(); + ScryptoUntypedTraverser::new( &bytes, - ExpectedStart::PayloadPrefix(SCRYPTO_SBOR_V1_PAYLOAD_PREFIX), - VecTraverserConfig { + UntypedTraverserConfig { max_depth: SCRYPTO_SBOR_V1_MAX_DEPTH, check_exact_end: true, }, - ); - let mut references = Vec::::new(); - let mut owned_nodes = Vec::::new(); - loop { - let event = traverser.next_event(); - match event.event { - TraversalEvent::ContainerStart(_) => {} - TraversalEvent::ContainerEnd(_) => {} - TraversalEvent::TerminalValue(r) => { - if let traversal::TerminalValueRef::Custom(c) = r { - match c.0 { - ScryptoCustomValue::Reference(node_id) => { - references.push(node_id.0.into()); - } - ScryptoCustomValue::Own(node_id) => { - owned_nodes.push(node_id.0.into()); - } - ScryptoCustomValue::Decimal(_) - | ScryptoCustomValue::PreciseDecimal(_) - | ScryptoCustomValue::NonFungibleLocalId(_) => {} - } - } - } - TraversalEvent::TerminalValueBatch(_) => {} - TraversalEvent::End => { - break; - } - TraversalEvent::DecodeError(e) => { - return Err(e); - } - } - } + ) + .run_from_start( + ExpectedStart::PayloadPrefix(SCRYPTO_SBOR_V1_PAYLOAD_PREFIX), + &mut aggregates, + )?; Ok(Self { bytes, - references, - owned_nodes, + references: aggregates.references, + owned_nodes: aggregates.owned_nodes, scrypto_value: RefCell::new(None), }) } From 0d63ac48b0a0bd85feb0351cbb87f9631bd19f47 Mon Sep 17 00:00:00 2001 From: David Edey Date: Wed, 25 Sep 2024 22:57:10 +0100 Subject: [PATCH 3/3] tweak: Move some code around --- radix-common/src/data/scrypto/definitions.rs | 1 + sbor/src/lib.rs | 1 + sbor/src/traversal/mod.rs | 2 + .../{untyped => }/traversal_traits.rs | 80 +++++++++++++++++++ sbor/src/traversal/typed/typed_traverser.rs | 66 +++++++++++++-- sbor/src/traversal/untyped/mod.rs | 2 - 6 files changed, 145 insertions(+), 7 deletions(-) rename sbor/src/traversal/{untyped => }/traversal_traits.rs (59%) diff --git 
a/radix-common/src/data/scrypto/definitions.rs b/radix-common/src/data/scrypto/definitions.rs index 25c5ffb23b0..5367c577149 100644 --- a/radix-common/src/data/scrypto/definitions.rs +++ b/radix-common/src/data/scrypto/definitions.rs @@ -9,6 +9,7 @@ pub type ScryptoEncoder<'a> = VecEncoder<'a, ScryptoCustomValueKind>; pub type ScryptoDecoder<'a> = VecDecoder<'a, ScryptoCustomValueKind>; #[allow(deprecated)] pub type ScryptoTraverser<'a> = VecTraverser<'a, ScryptoCustomTraversal>; +pub type ScryptoUntypedTraverser<'a> = UntypedTraverser<'a, ScryptoCustomTraversal>; pub type ScryptoValueKind = ValueKind; pub type ScryptoValue = Value; // ScryptoRawValue and friends are defined in custom_payload_wrappers.rs diff --git a/sbor/src/lib.rs b/sbor/src/lib.rs index 5f16fae30a8..0dcd3a2908c 100644 --- a/sbor/src/lib.rs +++ b/sbor/src/lib.rs @@ -123,6 +123,7 @@ pub(crate) mod internal_prelude { // These are mostly used for more advanced use cases, // so aren't included in the general prelude pub use crate::basic::*; + pub use crate::basic_well_known_types::*; pub use crate::decoder::*; pub use crate::encoder::*; pub use crate::payload_validation::*; diff --git a/sbor/src/traversal/mod.rs b/sbor/src/traversal/mod.rs index e8600c6b97d..13193becfea 100644 --- a/sbor/src/traversal/mod.rs +++ b/sbor/src/traversal/mod.rs @@ -1,7 +1,9 @@ mod path_formatting; +mod traversal_traits; mod typed; mod untyped; pub use path_formatting::*; +pub use traversal_traits::*; pub use typed::*; pub use untyped::*; diff --git a/sbor/src/traversal/untyped/traversal_traits.rs b/sbor/src/traversal/traversal_traits.rs similarity index 59% rename from sbor/src/traversal/untyped/traversal_traits.rs rename to sbor/src/traversal/traversal_traits.rs index b51a2c51d12..337ca2c1721 100644 --- a/sbor/src/traversal/untyped/traversal_traits.rs +++ b/sbor/src/traversal/traversal_traits.rs @@ -109,3 +109,83 @@ pub struct OnError<'t, T: CustomTraversal> { pub error: DecodeError, pub location: Location<'t, T>, } + 
+// We add this allow so that the placeholder names don't have to start with underscores +#[allow(unused_variables)] +pub trait TypedPayloadVisitor<'de, E: CustomExtension> { + type Output<'t, 's> + where + 's: 't; + + #[inline] + #[must_use] + fn on_container_start<'t, 's>( + &mut self, + details: OnContainerStartTyped<'t, 's, E::CustomTraversal>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_terminal_value<'t, 's>( + &mut self, + details: OnTerminalValueTyped<'t, 's, 'de, E::CustomTraversal>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_terminal_value_batch<'t, 's>( + &mut self, + details: OnTerminalValueBatchTyped<'t, 's, 'de, E::CustomTraversal>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[inline] + #[must_use] + fn on_container_end<'t, 's>( + &mut self, + details: OnContainerEndTyped<'t, 's, E::CustomTraversal>, + ) -> ControlFlow> { + ControlFlow::Continue(()) + } + + #[must_use] + fn on_error<'t, 's>(&mut self, details: OnErrorTyped<'t, 's, E>) -> Self::Output<'t, 's>; + + #[must_use] + fn on_traversal_end<'t, 's>(&mut self, details: OnTraversalEndTyped) -> Self::Output<'t, 's>; +} + +pub struct OnContainerStartTyped<'t, 's, T: CustomTraversal> { + pub local_type_id: LocalTypeId, + pub header: ContainerHeader, + pub location: TypedLocation<'t, 's, T>, +} + +pub struct OnTerminalValueTyped<'t, 's, 'de, T: CustomTraversal> { + pub local_type_id: LocalTypeId, + pub value: TerminalValueRef<'de, T>, + pub location: TypedLocation<'t, 's, T>, +} + +pub struct OnTerminalValueBatchTyped<'t, 's, 'de, T: CustomTraversal> { + pub local_type_id: LocalTypeId, + pub value_batch: TerminalValueBatchRef<'de>, + pub location: TypedLocation<'t, 's, T>, +} + +pub struct OnContainerEndTyped<'t, 's, T: CustomTraversal> { + pub local_type_id: LocalTypeId, + pub location: TypedLocation<'t, 's, T>, +} + +pub struct OnTraversalEndTyped {} + +pub struct OnErrorTyped<'t, 's, E: 
CustomExtension> { + pub error: TypedTraversalError, + pub location: TypedLocation<'t, 's, E::CustomTraversal>, +} diff --git a/sbor/src/traversal/typed/typed_traverser.rs b/sbor/src/traversal/typed/typed_traverser.rs index 43082cbe2b6..1e4fa8af0ee 100644 --- a/sbor/src/traversal/typed/typed_traverser.rs +++ b/sbor/src/traversal/typed/typed_traverser.rs @@ -1,8 +1,4 @@ -use super::*; -use crate::basic_well_known_types::ANY_TYPE; -use crate::rust::prelude::*; -use crate::traversal::*; -use crate::*; +use crate::internal_prelude::*; pub fn traverse_payload_with_types<'de, 's, E: CustomExtension>( payload: &'de [u8], @@ -142,6 +138,66 @@ impl<'de, 's, E: CustomExtension> TypedTraverser<'de, 's, E> { } } + /// Allows migrating off `next_event` before it's removed + pub fn traverse<'t, V: TypedPayloadVisitor<'de, E>>( + &'t mut self, + visitor: &mut V, + ) -> V::Output<'t, 's> { + match self.traverse_internal(visitor) { + ControlFlow::Continue(_) => unreachable!("Never returns a continue"), + ControlFlow::Break(output) => output, + } + } + + fn traverse_internal<'t, V: TypedPayloadVisitor<'de, E>>( + &'t mut self, + visitor: &mut V, + ) -> ControlFlow> { + loop { + // SAFETY: Work around the current borrow checker, which is sound as per this thread: + // https://users.rust-lang.org/t/mutable-borrow-in-loop-borrow-checker-query/118081/3 + // Unsafe syntax borrowed from here: https://docs.rs/polonius-the-crab/latest/polonius_the_crab/ + // Can remove this once the polonius borrow checker hits stable + let fixed_self: &mut TypedTraverser<'de, 's, E> = unsafe { &mut *(self as *mut _) }; + let TypedLocatedTraversalEvent { location, event } = fixed_self.next_event(); + match event { + TypedTraversalEvent::ContainerStart(local_type_id, header) => { + visitor.on_container_start(OnContainerStartTyped { + local_type_id, + header, + location, + })?; + } + TypedTraversalEvent::ContainerEnd(local_type_id, _header) => { + visitor.on_container_end(OnContainerEndTyped { + 
local_type_id, + location, + })?; + } + TypedTraversalEvent::TerminalValue(local_type_id, value) => { + visitor.on_terminal_value(OnTerminalValueTyped { + local_type_id, + value, + location, + })?; + } + TypedTraversalEvent::TerminalValueBatch(local_type_id, value_batch) => { + visitor.on_terminal_value_batch(OnTerminalValueBatchTyped { + local_type_id, + value_batch, + location, + })?; + } + TypedTraversalEvent::Error(error) => { + ControlFlow::Break(visitor.on_error(OnErrorTyped { error, location }))?; + } + TypedTraversalEvent::End => { + ControlFlow::Break(visitor.on_traversal_end(OnTraversalEndTyped {}))?; + } + } + } + } + pub fn next_event(&mut self) -> TypedLocatedTraversalEvent<'_, 's, 'de, E> { let (typed_event, location) = Self::next_event_internal(&mut self.traverser, &mut self.state); diff --git a/sbor/src/traversal/untyped/mod.rs b/sbor/src/traversal/untyped/mod.rs index 28610ec0a95..4e5ce4fea8d 100644 --- a/sbor/src/traversal/untyped/mod.rs +++ b/sbor/src/traversal/untyped/mod.rs @@ -2,13 +2,11 @@ use crate::internal_prelude::*; mod event_stream_traverser; mod events; -mod traversal_traits; mod untyped_traverser; mod utility_visitors; pub use event_stream_traverser::*; pub use events::*; -pub use traversal_traits::*; pub use untyped_traverser::*; pub use utility_visitors::*;