diff --git a/trace_decoder/src/core.rs b/trace_decoder/src/core.rs index 3fbed8189..181fa054f 100644 --- a/trace_decoder/src/core.rs +++ b/trace_decoder/src/core.rs @@ -241,11 +241,8 @@ fn start( ) } WireDisposition::Type2 => { - let crate::type2::Frontend { - trie, - code, - collation, - } = crate::type2::frontend(instructions)?; + let crate::type2::Frontend { trie, code } = + crate::type2::frontend(instructions)?; todo!() } diff --git a/trace_decoder/src/type1.rs b/trace_decoder/src/type1.rs index cfa0ed615..bce0c9134 100644 --- a/trace_decoder/src/type1.rs +++ b/trace_decoder/src/type1.rs @@ -12,7 +12,7 @@ use mpt_trie::partial_trie::OnOrphanedHashNode; use nunny::NonEmpty; use u4::U4; -use crate::typed_mpt::{StateMpt, StateTrie as _, StorageTrie, MptKey}; +use crate::typed_mpt::{MptKey, StateMpt, StorageTrie}; use crate::wire::{Instruction, SmtLeaf}; #[derive(Debug, Clone)] @@ -380,6 +380,8 @@ fn finish_stack(v: &mut Vec) -> anyhow::Result { #[test] fn test_tries() { + use crate::typed_mpt::StateTrie as _; + for (ix, case) in serde_json::from_str::>(include_str!("cases/zero_jerigon.json")) .unwrap() diff --git a/trace_decoder/src/type2.rs b/trace_decoder/src/type2.rs index eb70978f9..5d88be6b4 100644 --- a/trace_decoder/src/type2.rs +++ b/trace_decoder/src/type2.rs @@ -1,37 +1,36 @@ //! Frontend for the witness format emitted by e.g [`0xPolygonHermez/cdk-erigon`](https://github.com/0xPolygonHermez/cdk-erigon/) //! Ethereum node. -use std::{ - collections::{HashMap, HashSet}, - iter, -}; +use std::collections::{BTreeMap, HashSet}; use anyhow::{bail, ensure, Context as _}; -use bitvec::vec::BitVec; -use either::Either; -use ethereum_types::BigEndianHash as _; -use itertools::{EitherOrBoth, Itertools as _}; +use ethereum_types::{Address, BigEndianHash as _, U256}; +use itertools::EitherOrBoth; +use keccak_hash::H256; use nunny::NonEmpty; -use plonky2::field::types::Field; +use plonky2::field::types::{Field, Field64 as _}; +use smt_trie::keys::{key_balance, key_code, key_code_length, key_nonce, key_storage}; +use stackstack::Stack; use crate::{ - typed_mpt::StateSmt, + typed_mpt::SmtKey, wire::{Instruction, SmtLeaf, SmtLeafType}, }; type SmtTrie = smt_trie::smt::Smt; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +/// Combination of all the [`SmtLeaf::node_type`]s +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct CollatedLeaf { pub balance: Option, pub nonce: Option, - pub code_hash: Option, - pub storage_root: Option, + pub code: Option, + pub code_length: Option, + pub storage: BTreeMap, } pub struct Frontend { pub trie: SmtTrie, pub code: HashSet>>, - pub collation: HashMap, } /// # Panics @@ -39,13 +38,8 @@ pub struct Frontend { /// NOT call this function on untrusted inputs. pub fn frontend(instructions: impl IntoIterator) -> anyhow::Result { let (node, code) = fold(instructions).context("couldn't fold smt from instructions")?; - let (trie, collation) = - node2trie(node).context("couldn't construct trie and collation from folded node")?; - Ok(Frontend { - trie, - code, - collation, - }) + let trie = node2trie(node).context("couldn't construct trie and collation from folded node")?; + Ok(Frontend { trie, code }) } /// Node in a binary (SMT) tree. @@ -107,9 +101,9 @@ fn fold1(instructions: impl IntoIterator) -> anyhow::Result< Ok(Some(match mask { // note that the single-child bits are reversed... - 0b0001 => Node::Branch(EitherOrBoth::Left(get_child()?)), - 0b0010 => Node::Branch(EitherOrBoth::Right(get_child()?)), - 0b0011 => Node::Branch(EitherOrBoth::Both(get_child()?, get_child()?)), + 0b_01 => Node::Branch(EitherOrBoth::Left(get_child()?)), + 0b_10 => Node::Branch(EitherOrBoth::Right(get_child()?)), + 0b_11 => Node::Branch(EitherOrBoth::Both(get_child()?, get_child()?)), other => bail!("unexpected bit pattern in Branch mask: {:#b}", other), })) } @@ -121,113 +115,111 @@ fn fold1(instructions: impl IntoIterator) -> anyhow::Result< } } -/// Pack a [`Node`] tree into an [`SmtTrie`]. -/// Also summarizes the [`Node::Leaf`]s out-of-band. -/// -/// # Panics -/// - if the tree is too deep. -/// - if [`SmtLeaf::address`] or [`SmtLeaf::value`] are the wrong length. -/// - if [`SmtLeafType::Storage`] is the wrong length. -/// - [`SmtTrie`] panics internally. -fn node2trie( - node: Node, -) -> anyhow::Result<(SmtTrie, HashMap)> { +fn node2trie(node: Node) -> anyhow::Result { let mut trie = SmtTrie::default(); - - let (hashes, leaves) = - iter_leaves(node).partition_map::, Vec<_>, _, _, _>(|(path, leaf)| match leaf { - Either::Left(it) => Either::Left((path, it)), - Either::Right(it) => Either::Right(it), - }); - - let mut lens = std::collections::BTreeMap::<_, usize>::new(); - - for (path, hash) in hashes { - *lens.entry(path.len()).or_default() += 1; - // needs to be called before `set`, below, "to avoid any issues" according - // to the smt docs. + let mut hashes = BTreeMap::new(); + let mut leaves = BTreeMap::new(); + visit(&mut hashes, &mut leaves, Stack::new(), node)?; + for (key, hash) in hashes { trie.set_hash( - bits2bits(path), + key.into_smt_bits(), smt_trie::smt::HashOut { elements: { - let ethereum_types::U256(arr) = ethereum_types::H256(hash).into_uint(); + let ethereum_types::U256(arr) = hash.into_uint(); + for u in arr { + ensure!(u < smt_trie::smt::F::ORDER); + } arr.map(smt_trie::smt::F::from_canonical_u64) }, }, - ) + ); } - dbg!(lens); - - let mut collated = HashMap::::new(); - for SmtLeaf { - node_type, - address, - value, - } in leaves + for ( + addr, + CollatedLeaf { + balance, + nonce, + code, + code_length, + storage, + }, + ) in leaves { - let address = ethereum_types::Address::from_slice(&address); - let collated = collated.entry(address).or_default(); - let value = ethereum_types::U256::from_big_endian(&value); - let key = match node_type { - SmtLeafType::Balance => { - ensure!(collated.balance.is_none(), "double write of field"); - collated.balance = Some(value); - smt_trie::keys::key_balance(address) - } - SmtLeafType::Nonce => { - ensure!(collated.nonce.is_none(), "double write of field"); - collated.nonce = Some(value); - smt_trie::keys::key_nonce(address) - } - SmtLeafType::Code => { - ensure!(collated.code_hash.is_none(), "double write of field"); - collated.code_hash = Some({ - let mut it = ethereum_types::H256::zero(); - value.to_big_endian(it.as_bytes_mut()); - it - }); - smt_trie::keys::key_code(address) - } - SmtLeafType::Storage(it) => { - ensure!(collated.storage_root.is_none(), "double write of field"); - // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/275 - // do we not do anything with the storage here? - smt_trie::keys::key_storage(address, ethereum_types::U256::from_big_endian(&it)) + for (value, key_fn) in [ + (balance, key_balance as fn(_) -> _), + (nonce, key_nonce), + (code, key_code), + (code_length, key_code_length), + ] { + if let Some(value) = value { + trie.set(key_fn(addr), value); } - SmtLeafType::CodeLength => smt_trie::keys::key_code_length(address), - }; - trie.set(key, value) - } - Ok((trie, collated)) -} - -/// # Panics -/// - on overcapacity -fn bits2bits(ours: BitVec) -> smt_trie::bits::Bits { - let mut theirs = smt_trie::bits::Bits::empty(); - for it in ours { - theirs.push_bit(it) + } + for (slot, value) in storage { + trie.set(key_storage(addr, slot), value); + } } - theirs + Ok(trie) } -/// Simple, inefficient visitor of all leaves of the [`Node`] tree. -#[allow(clippy::type_complexity)] -fn iter_leaves(node: Node) -> Box)>> { +fn visit( + hashes: &mut BTreeMap, + leaves: &mut BTreeMap, + path: Stack, + node: Node, +) -> anyhow::Result<()> { match node { - Node::Hash(it) => Box::new(iter::once((BitVec::new(), Either::Left(it)))), - Node::Branch(it) => { - let (left, right) = it.left_and_right(); - let left = left - .into_iter() - .flat_map(|it| iter_leaves(*it).update(|(path, _)| path.insert(0, false))); - let right = right - .into_iter() - .flat_map(|it| iter_leaves(*it).update(|(path, _)| path.insert(0, true))); - Box::new(left.chain(right)) + Node::Branch(children) => { + let (left, right) = children.left_and_right(); + if let Some(left) = left { + visit(hashes, leaves, path.pushed(false), *left)?; + } + if let Some(right) = right { + visit(hashes, leaves, path.pushed(true), *right)?; + } + } + Node::Hash(hash) => { + hashes.insert(SmtKey::new(path.iter().copied())?, H256(hash)); + } + Node::Leaf(SmtLeaf { + node_type, + address, // TODO(0xaatif): field should be fixed length + value, // TODO(0xaatif): field should be fixed length + }) => { + let address = Address::from_slice(&address); + let collated = leaves.entry(address).or_default(); + let value = U256::from_big_endian(&value); + macro_rules! ensure { + ($expr:expr) => { + ::anyhow::ensure!($expr, "double write of field for address {}", address) + }; + } + match node_type { + SmtLeafType::Balance => { + ensure!(collated.balance.is_none()); + collated.balance = Some(value) + } + SmtLeafType::Nonce => { + ensure!(collated.nonce.is_none()); + collated.nonce = Some(value) + } + SmtLeafType::Code => { + ensure!(collated.code.is_none()); + collated.code = Some(value) + } + SmtLeafType::Storage(slot) => { + // TODO(0xaatif): ^ field should be fixed length + let clobbered = collated.storage.insert(U256::from_big_endian(&slot), value); + ensure!(clobbered.is_none()) + } + SmtLeafType::CodeLength => { + ensure!(collated.code_length.is_none()); + collated.code_length = Some(value) + } + }; } - Node::Leaf(it) => Box::new(iter::once((BitVec::new(), Either::Right(it)))), } + Ok(()) } #[test] @@ -241,10 +233,10 @@ fn test_tries() { println!("case {}", ix); let instructions = crate::wire::parse(&case.bytes).unwrap(); let frontend = frontend(instructions).unwrap(); - // assert_eq!(case.expected_state_root, { - // let mut it = [0; 32]; - // smt_trie::utils::hashout2u(frontend.trie.root).to_big_endian(&mut - // it); ethereum_types::H256(it) - // }); + assert_eq!(case.expected_state_root, { + let mut it = [0; 32]; + smt_trie::utils::hashout2u(frontend.trie.root).to_big_endian(&mut it); + ethereum_types::H256(it) + }); } } diff --git a/trace_decoder/src/typed_mpt.rs b/trace_decoder/src/typed_mpt.rs index 4df34891a..fa14f2ffc 100644 --- a/trace_decoder/src/typed_mpt.rs +++ b/trace_decoder/src/typed_mpt.rs @@ -244,7 +244,7 @@ impl SmtKey { Ok(Self { bits, len }) } - fn into_bits(self) -> smt_trie::bits::Bits { + pub fn into_smt_bits(self) -> smt_trie::bits::Bits { let mut bits = smt_trie::bits::Bits::default(); for bit in self.as_bitslice() { bits.push_bit(*bit) @@ -522,7 +522,7 @@ impl StateSmt { } = self; let mut smt = smt_trie::smt::Smt::::default(); for (k, v) in hashed_out { - smt.set_hash(k.into_bits(), conv_hash::eth2smt(*v)); + smt.set_hash(k.into_smt_bits(), conv_hash::eth2smt(*v)); } for ( addr, diff --git a/trace_decoder/src/wire.rs b/trace_decoder/src/wire.rs index 6f56f1e44..9d1a7fb10 100644 --- a/trace_decoder/src/wire.rs +++ b/trace_decoder/src/wire.rs @@ -82,8 +82,8 @@ pub enum Instruction { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SmtLeaf { pub node_type: SmtLeafType, - pub address: NonEmpty>, - pub value: NonEmpty>, + pub address: NonEmpty>, // TODO(0xaatif): this should be a fixed length + pub value: NonEmpty>, // TODO(0xaatif): this should be a fixed length } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -91,7 +91,7 @@ pub enum SmtLeafType { Balance, Nonce, Code, - Storage(NonEmpty>), + Storage(NonEmpty>), // TODO(0xaatif): this should be a fixed length CodeLength, }