Skip to content

Commit

Permalink
fix: enable the mpt cache (risc0#62)
Browse files Browse the repository at this point in the history
* Refactor MptNode struct and methods

* fix: avoid RefCell cross-await
  • Loading branch information
johntaiko authored Mar 19, 2024
1 parent 3268d70 commit 46825d6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 29 deletions.
40 changes: 14 additions & 26 deletions primitives/src/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use core::{
cell::RefCell,
cmp,
fmt::{Debug, Write},
iter, mem,
Expand Down Expand Up @@ -131,10 +132,10 @@ pub fn keccak(data: impl AsRef<[u8]>) -> [u8; 32] {
pub struct MptNode {
/// The type and data of the node.
data: MptNodeData,
// / Cache for a previously computed reference of this node. This is skipped during
// / serialization.
// #[serde(skip)]
// cached_reference: RefCell<Option<MptNodeReference>>,
/// Cache for a previously computed reference of this node. This is skipped during
/// serialization.
#[serde(skip)]
cached_reference: RefCell<Option<MptNodeReference>>,
}

/// Represents custom error types for the sparse Merkle Patricia Trie (MPT).
Expand Down Expand Up @@ -209,7 +210,7 @@ impl From<MptNodeData> for MptNode {
fn from(value: MptNodeData) -> Self {
Self {
data: value,
// cached_reference: RefCell::new(None),
cached_reference: RefCell::new(None),
}
}
}
Expand Down Expand Up @@ -371,11 +372,10 @@ impl MptNode {
/// storage or transmission purposes.
#[inline]
pub fn reference(&self) -> MptNodeReference {
// self.cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
// .clone()
self.calc_reference()
self.cached_reference
.borrow_mut()
.get_or_insert_with(|| self.calc_reference())
.clone()
}

/// Computes and returns the 256-bit hash of the node.
Expand All @@ -385,11 +385,7 @@ impl MptNode {
pub fn hash(&self) -> B256 {
match self.data {
MptNodeData::Null => EMPTY_ROOT,
// _ => match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
_ => match self.calc_reference() {
_ => match self.reference() {
MptNodeReference::Digest(digest) => digest,
MptNodeReference::Bytes(bytes) => keccak(bytes).into(),
},
Expand All @@ -398,11 +394,7 @@ impl MptNode {

/// Encodes the [MptNodeReference] of this node into the `out` buffer.
fn reference_encode(&self, out: &mut dyn alloy_rlp::BufMut) {
// match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
match self.calc_reference() {
match self.reference() {
// if the reference is an RLP-encoded byte slice, copy it directly
MptNodeReference::Bytes(bytes) => out.put_slice(&bytes),
// if the reference is a digest, RLP-encode it with its fixed known length
Expand All @@ -415,11 +407,7 @@ impl MptNode {

/// Returns the length of the encoded [MptNodeReference] of this node.
fn reference_length(&self) -> usize {
// match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
match self.calc_reference() {
match self.reference() {
MptNodeReference::Bytes(bytes) => bytes.len(),
MptNodeReference::Digest(_) => 1 + 32,
}
Expand Down Expand Up @@ -774,7 +762,7 @@ impl MptNode {
}

fn invalidate_ref_cache(&mut self) {
// self.cached_reference.borrow_mut().take();
self.cached_reference.borrow_mut().take();
}

/// Returns the number of traversable nodes in the trie.
Expand Down
6 changes: 3 additions & 3 deletions raiko-host/src/prover/proof/risc0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ pub async fn execute_risc0(
req: &Risc0ProofParams,
) -> Result<Risc0Response, String> {
println!("elf code length: {}", RISC0_METHODS_ELF.len());
let encoded_input = to_vec(&input).expect("Could not serialize proving input!");

let result = maybe_prove::<GuestInput<EthereumTxEssence>, GuestOutput>(
req,
&input,
encoded_input,
RISC0_METHODS_ELF,
&output,
Default::default(),
Expand Down Expand Up @@ -227,13 +228,12 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(

pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOwned>(
req: &Risc0ProofParams,
input: &I,
encoded_input: Vec<u32>,
elf: &[u8],
expected_output: &O,
assumptions: (Vec<Assumption>, Vec<String>),
) -> Option<(String, Receipt)> {
let (assumption_instances, assumption_uuids) = assumptions;
let encoded_input = to_vec(input).expect("Could not serialize proving input!");

let encoded_output =
to_vec(expected_output).expect("Could not serialize expected proving output!");
Expand Down

0 comments on commit 46825d6

Please sign in to comment.