Skip to content

Commit

Permalink
cleanup MastNodeType serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Jul 12, 2024
1 parent 43cadeb commit 7dec428
Showing 1 changed file with 71 additions and 178 deletions.
249 changes: 71 additions & 178 deletions core/src/mast/serialization/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,28 @@ impl MastNodeType {

impl Serializable for MastNodeType {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let serialized_bytes = {
let mut serialized_bytes = self.inline_data_to_bytes();
let discriminant = self.discriminant() as u64;
assert!(discriminant <= 0b1111);

// Tag is always placed in the first four bytes
let discriminant = self.discriminant();
assert!(discriminant <= 0b1111);
serialized_bytes[0] |= discriminant << 4;

serialized_bytes
let payload = match self {
MastNodeType::Join {
left_child_id: left,
right_child_id: right,
} => Self::encode_u32_pair(*left, *right),
MastNodeType::Split {
if_branch_id: if_branch,
else_branch_id: else_branch,
} => Self::encode_u32_pair(*if_branch, *else_branch),
MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body),
MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len),
MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id),
MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id),
MastNodeType::Dyn => 0,
MastNodeType::External => 0,
};

serialized_bytes.write_into(target)
let value = (discriminant << 60) | payload;
target.write_u64(value);
}
}

Expand All @@ -230,221 +240,104 @@ impl MastNodeType {
unsafe { *<*const _>::from(self).cast::<u8>() }
}

fn inline_data_to_bytes(&self) -> [u8; 8] {
match self {
MastNodeType::Join {
left_child_id: left,
right_child_id: right,
} => Self::encode_u32_pair(*left, *right),
MastNodeType::Split {
if_branch_id: if_branch,
else_branch_id: else_branch,
} => Self::encode_u32_pair(*if_branch, *else_branch),
MastNodeType::Loop { body_id: body } => Self::encode_u32_payload(*body),
MastNodeType::Block { offset, len } => Self::encode_u32_pair(*offset, *len),
MastNodeType::Call { callee_id } => Self::encode_u32_payload(*callee_id),
MastNodeType::SysCall { callee_id } => Self::encode_u32_payload(*callee_id),
MastNodeType::Dyn => [0; 8],
MastNodeType::External => [0; 8],
/// Encodes two u32 numbers in the first 60 bits of a `u64`.
///
/// # Panics
/// - Panics if either `left_value` or `right_value` doesn't fit in 30 bits.
fn encode_u32_pair(left_value: u32, right_value: u32) -> u64 {
if left_value.leading_zeros() < 2 {
panic!(
"MastNodeType::encode_u32_pair: left value doesn't fit in 30 bits: {}",
left_value
);
}
}

fn encode_u32_pair(left_value: u32, right_value: u32) -> [u8; 8] {
assert!(left_value < 2_u32.pow(30));
assert!(right_value < 2_u32.pow(30));

let mut result: [u8; 8] = [0_u8; 8];

// write left child into result
{
let [lsb, a, b, msb] = left_value.to_le_bytes();
result[0] |= lsb >> 4;
result[1] |= lsb << 4;
result[1] |= a >> 4;
result[2] |= a << 4;
result[2] |= b >> 4;
result[3] |= b << 4;

// msb is different from lsb, a and b since its 2 most significant bits are guaranteed
// to be 0, and hence not encoded.
//
// More specifically, let the bits of msb be `00abcdef`. We encode `abcd` in
// `result[3]`, and `ef` as the most significant bits of `result[4]`.
result[3] |= msb >> 2;
result[4] |= msb << 6;
};

// write right child into result
{
// Recall that `result[4]` contains 2 bits from the left child id in the most
// significant bits. Also, the most significant byte of the right child is guaranteed to
// fit in 6 bits. Hence, we use big endian format for the right child id to simplify
// encoding and decoding.
let [msb, a, b, lsb] = right_value.to_be_bytes();

result[4] |= msb;
result[5] = a;
result[6] = b;
result[7] = lsb;
};
if right_value.leading_zeros() < 2 {
panic!(
"MastNodeType::encode_u32_pair: right value doesn't fit in 30 bits: {}",
left_value
);
}

result
((left_value as u64) << 30) | (right_value as u64)
}

fn encode_u32_payload(payload: u32) -> [u8; 8] {
let [payload_byte1, payload_byte2, payload_byte3, payload_byte4] = payload.to_be_bytes();

[0, payload_byte1, payload_byte2, payload_byte3, payload_byte4, 0, 0, 0]
fn encode_u32_payload(payload: u32) -> u64 {
payload as u64
}
}

impl Deserializable for MastNodeType {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let bytes: [u8; 8] = source.read_array()?;
let (discriminant, payload) = {
let value = source.read_u64()?;

let tag = bytes[0] >> 4;
// 4 bits
let discriminant = (value >> 60) as u8;
// 60 bits
let payload = value & 0x0F_FF_FF_FF_FF_FF_FF_FF;

(discriminant, payload)
};

match tag {
match discriminant {
JOIN => {
let (left_child_id, right_child_id) = Self::decode_u32_pair(bytes);
let (left_child_id, right_child_id) = Self::decode_u32_pair(payload);
Ok(Self::Join {
left_child_id,
right_child_id,
})
}
SPLIT => {
let (if_branch_id, else_branch_id) = Self::decode_u32_pair(bytes);
let (if_branch_id, else_branch_id) = Self::decode_u32_pair(payload);
Ok(Self::Split {
if_branch_id,
else_branch_id,
})
}
LOOP => {
let body_id = Self::decode_u32_payload(bytes);
let body_id = Self::decode_u32_payload(payload)?;
Ok(Self::Loop { body_id })
}
BLOCK => {
let (offset, len) = Self::decode_u32_pair(bytes);
let (offset, len) = Self::decode_u32_pair(payload);
Ok(Self::Block { offset, len })
}
CALL => {
let callee_id = Self::decode_u32_payload(bytes);
let callee_id = Self::decode_u32_payload(payload)?;
Ok(Self::Call { callee_id })
}
SYSCALL => {
let callee_id = Self::decode_u32_payload(bytes);
let callee_id = Self::decode_u32_payload(payload)?;
Ok(Self::SysCall { callee_id })
}
DYN => Ok(Self::Dyn),
EXTERNAL => Ok(Self::External),
_ => {
Err(DeserializationError::InvalidValue(format!("Invalid tag for MAST node: {tag}")))
}
_ => Err(DeserializationError::InvalidValue(format!(
"Invalid tag for MAST node: {discriminant}"
))),
}
}
}

/// Deserialization helpers
impl MastNodeType {
fn decode_u32_pair(buffer: [u8; 8]) -> (u32, u32) {
let first = {
let mut first_le_bytes = [0_u8; 4];

first_le_bytes[0] = buffer[0] << 4;
first_le_bytes[0] |= buffer[1] >> 4;

first_le_bytes[1] = buffer[1] << 4;
first_le_bytes[1] |= buffer[2] >> 4;

first_le_bytes[2] = buffer[2] << 4;
first_le_bytes[2] |= buffer[3] >> 4;

first_le_bytes[3] = (buffer[3] & 0b1111) << 2;
first_le_bytes[3] |= buffer[4] >> 6;

u32::from_le_bytes(first_le_bytes)
};

let second = {
let mut second_be_bytes = [0_u8; 4];

second_be_bytes[0] = buffer[4] & 0b0011_1111;
second_be_bytes[1] = buffer[5];
second_be_bytes[2] = buffer[6];
second_be_bytes[3] = buffer[7];

u32::from_be_bytes(second_be_bytes)
};

(first, second)
}

pub fn decode_u32_payload(payload: [u8; 8]) -> u32 {
let payload_be_bytes = [payload[1], payload[2], payload[3], payload[4]];

u32::from_be_bytes(payload_be_bytes)
}
}

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
use super::*;
use alloc::vec::Vec;

#[test]
fn mast_node_type_serde_join() {
let left_child_id = 0b00111001_11101011_01101100_11011000;
let right_child_id = 0b00100111_10101010_11111111_11001110;

let mast_node_type = MastNodeType::Join {
left_child_id,
right_child_id,
};

let mut encoded_mast_node_type: Vec<u8> = Vec::new();
mast_node_type.write_into(&mut encoded_mast_node_type);
/// Decodes two `u32` numbers from a 60-bit payload.
fn decode_u32_pair(payload: u64) -> (u32, u32) {
let left_value = (payload >> 30) as u32;
let right_value = (payload & 0x3F_FF_FF_FF) as u32;

// Note: Join's discriminant is 0
let expected_encoded_mast_node_type = [
0b00001101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111,
0b11001110,
];

assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type);

let (decoded_left, decoded_right) =
MastNodeType::decode_u32_pair(expected_encoded_mast_node_type);
assert_eq!(left_child_id, decoded_left);
assert_eq!(right_child_id, decoded_right);
(left_value, right_value)
}

#[test]
fn mast_node_type_serde_split() {
let if_branch_id = 0b00111001_11101011_01101100_11011000;
let else_branch_id = 0b00100111_10101010_11111111_11001110;

let mast_node_type = MastNodeType::Split {
if_branch_id,
else_branch_id,
};

let mut encoded_mast_node_type: Vec<u8> = Vec::new();
mast_node_type.write_into(&mut encoded_mast_node_type);

// Note: Split's discriminant is 1
let expected_encoded_mast_node_type = [
0b00011101, 0b10000110, 0b11001110, 0b10111110, 0b01100111, 0b10101010, 0b11111111,
0b11001110,
];

assert_eq!(expected_encoded_mast_node_type.to_vec(), encoded_mast_node_type);

let (decoded_if_branch, decoded_else_branch) =
MastNodeType::decode_u32_pair(expected_encoded_mast_node_type);
assert_eq!(if_branch_id, decoded_if_branch);
assert_eq!(else_branch_id, decoded_else_branch);
/// Decodes one `u32` number from a 60-bit payload.
///
/// Returns an error if the payload doesn't fit in a `u32`.
pub fn decode_u32_payload(payload: u64) -> Result<u32, DeserializationError> {
payload.try_into().map_err(|_| {
DeserializationError::InvalidValue(format!(
"Invalid payload: expected to fit in u32, but was {payload}"
))
})
}
}

0 comments on commit 7dec428

Please sign in to comment.