Skip to content

Commit

Permalink
refactor: fix bug in ser/de of inputs and outputs, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Fumuran committed Oct 1, 2024
1 parent 5736aa2 commit 4ea0ac0
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 28 deletions.
22 changes: 10 additions & 12 deletions core/src/stack/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,24 @@ impl IntoIterator for StackInputs {

impl Serializable for StackInputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(get_num_stack_values(self));
target.write_many(self.elements);
let num_stack_values = get_num_stack_values(self);
target.write_u8(num_stack_values);
target.write_many(&self.elements[..num_stack_values as usize]);
}
}

impl Deserializable for StackInputs {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let num_elements = source.read_u8()?;

// check that `num_elements` is valid
if num_elements > MIN_STACK_DEPTH as u8 {
return Err(DeserializationError::InvalidValue(format!(
"number of stack elements should not be greater than {}, but {} was found",
MIN_STACK_DEPTH, num_elements
)));
}

let mut elements = source.read_many::<Felt>(num_elements.into())?;
elements.resize(MIN_STACK_DEPTH, ZERO);
elements.reverse();

Ok(StackInputs { elements: elements.try_into().unwrap() })
StackInputs::new(elements).map_err(|_| {
DeserializationError::InvalidValue(format!(
"number of stack elements should not be greater than {}, but {} was found",
MIN_STACK_DEPTH, num_elements
))
})
}
}
3 changes: 3 additions & 0 deletions core/src/stack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub use inputs::StackInputs;
mod outputs;
pub use outputs::StackOutputs;

#[cfg(test)]
mod tests;

// CONSTANTS
// ================================================================================================

Expand Down
32 changes: 16 additions & 16 deletions core/src/stack/outputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ impl StackOutputs {
/// # Errors
/// Returns an error if:
/// - Any of the provided stack elements are invalid field elements.
pub fn try_from_ints(stack: Vec<u64>) -> Result<Self, OutputError> {
pub fn try_from_ints<I>(iter: I) -> Result<Self, OutputError>
where
I: IntoIterator<Item = u64>,
{
// Validate stack elements
let stack = stack
.iter()
.map(|v| Felt::try_from(*v))
let stack = iter
.into_iter()
.map(Felt::try_from)
.collect::<Result<Vec<Felt>, _>>()
.map_err(OutputError::InvalidStackElement)?;

Expand Down Expand Up @@ -123,26 +126,23 @@ impl From<[Felt; MIN_STACK_DEPTH]> for StackOutputs {

impl Serializable for StackOutputs {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(get_num_stack_values(self));
target.write_many(self.elements);
let num_stack_values = get_num_stack_values(self);
target.write_u8(num_stack_values);
target.write_many(&self.elements[..num_stack_values as usize]);
}
}

impl Deserializable for StackOutputs {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let num_elements = source.read_u8()?;

// check that `num_elements` is valid
if num_elements > MIN_STACK_DEPTH as u8 {
return Err(DeserializationError::InvalidValue(format!(
let elements = source.read_many::<Felt>(num_elements.into())?;

StackOutputs::new(elements).map_err(|_| {
DeserializationError::InvalidValue(format!(
"number of stack elements should not be greater than {}, but {} was found",
MIN_STACK_DEPTH, num_elements
)));
}

let mut elements = source.read_many::<Felt>(num_elements.into())?;
elements.resize(MIN_STACK_DEPTH, ZERO);

Ok(Self { elements: elements.try_into().unwrap() })
))
})
}
}
130 changes: 130 additions & 0 deletions core/src/stack/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use alloc::vec::Vec;

use crate::{
utils::{Deserializable, Serializable},
StackInputs, StackOutputs,
};

// SERDE INPUTS TESTS
// ================================================================================================

#[test]
fn test_inputs_simple() {
let source = Vec::<u64>::from([5, 4, 3, 2, 1]);
let mut serialized = Vec::new();
let inputs = StackInputs::try_from_ints(source.clone()).unwrap();

inputs.write_into(&mut serialized);

let mut expected_serialized = Vec::new();
expected_serialized.push(source.len() as u8);
source
.iter()
.rev()
.for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec()));

assert_eq!(serialized, expected_serialized);

let result = StackInputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

#[test]
fn test_inputs_full() {
let source = Vec::<u64>::from([16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]);
let mut serialized = Vec::new();
let inputs = StackInputs::try_from_ints(source.clone()).unwrap();

inputs.write_into(&mut serialized);

let mut expected_serialized = Vec::new();
expected_serialized.push(source.len() as u8);
source
.iter()
.rev()
.for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec()));

assert_eq!(serialized, expected_serialized);

let result = StackInputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

#[test]
fn test_inputs_empty() {
let mut serialized = Vec::new();
let inputs = StackInputs::try_from_ints([]).unwrap();

inputs.write_into(&mut serialized);

let expected_serialized = vec![0];

assert_eq!(serialized, expected_serialized);

let result = StackInputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

// SERDE OUTPUTS TESTS
// ================================================================================================

#[test]
fn test_outputs_simple() {
let source = Vec::<u64>::from([1, 2, 3, 4, 5]);
let mut serialized = Vec::new();
let inputs = StackOutputs::try_from_ints(source.clone()).unwrap();

inputs.write_into(&mut serialized);

let mut expected_serialized = Vec::new();
expected_serialized.push(source.len() as u8);
source
.iter()
.for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec()));

assert_eq!(serialized, expected_serialized);

let result = StackOutputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

#[test]
fn test_outputs_full() {
let source = Vec::<u64>::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
let mut serialized = Vec::new();
let inputs = StackOutputs::try_from_ints(source.clone()).unwrap();

inputs.write_into(&mut serialized);

let mut expected_serialized = Vec::new();
expected_serialized.push(source.len() as u8);
source
.iter()
.for_each(|v| expected_serialized.append(&mut v.to_le_bytes().to_vec()));

assert_eq!(serialized, expected_serialized);

let result = StackOutputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

#[test]
fn test_outputs_empty() {
let mut serialized = Vec::new();
let inputs = StackOutputs::try_from_ints([]).unwrap();

inputs.write_into(&mut serialized);

let expected_serialized = vec![0];

assert_eq!(serialized, expected_serialized);

let result = StackOutputs::read_from_bytes(&serialized).unwrap();

assert_eq!(*inputs, *result);
}

0 comments on commit 4ea0ac0

Please sign in to comment.