diff --git a/src/stdlib/mod.rs b/src/stdlib/mod.rs index 1c8dde950..b4389ce3b 100644 --- a/src/stdlib/mod.rs +++ b/src/stdlib/mod.rs @@ -99,7 +99,7 @@ pub fn init_stdlib_dep( server_mode: &mut Option, ) -> usize { // list the stdlib dependency in order - let libs = vec!["bits", "comparator", "multiplexer", "mimc", "int"]; + let libs = vec!["bits", "comparator", "multiplexer", "mimc", "int", "smt"]; let mut node_id = node_id; diff --git a/src/stdlib/native/smt/lib.no b/src/stdlib/native/smt/lib.no new file mode 100644 index 000000000..349d9f8c0 --- /dev/null +++ b/src/stdlib/native/smt/lib.no @@ -0,0 +1,434 @@ +use std::mimc; +use std::bits; + +// SMT +// This is an implementation of sparse merkle tree according to the circom specification +// learn more about it here - https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf + +// A new_root computation is computed as follows +// state +// top old_root new_root +// / \ / \ +// top [sibling] [] [sibling] [] +// / \ / \ +// new [sibling] [old_key] <- [sibling] [] --> this will change in case of update or del +// The key we are doing ----------- / \ +// operation follows this path and this key will be [old_key] [new-key] + +// generating member ship or (non-)memberships proofs is a more simpler than computing new roots we just compute roots using the siblings and keys , values +// and verify whether the computed roots is equal to the root of the tree + + + + +// State Constants for computing new roots +const st_top = 0; +const st_upd = 1; +const st_old_zero = 2; +const st_new = 3; +const st_btn = 4; +const st_na = 5; + + +// State Constants for the verification states +const v_st_top= 0; +const v_st_iz = 1; +const v_st_inz = 2; +const v_st_inew = 3; +const v_st_na = 4; + + + +/// Computes the hash for a leaf in the Sparse Merkle tree +/// `hash = H ( 1 | key | value )` +/// +/// # Parameters +/// `key`: key for the node +/// `value`: Value for the node +/// +/// # Returns +/// Field: mimc7 hash for key = 1 and values = [key,value] +fn compute_leaf_hash(key: Field, value: Field) -> Field { + let hash_values = [key, value]; + return mimc::mimc7_hash(hash_values, 1); +} + +/// Computes the hash for the internal nodes in the SMT +/// `Hash = H(lh | rh)` +/// +/// # Parameters +/// `left_child_hash`: The hash of the left child node +/// `right_child_hash`: The hash of the right child node +/// +/// # Returns +/// Field: mimc7 hash for key = 0 and values = [lh , rh] +fn compute_internal_node_hash(left_child_hash: Field , right_child_hash: Field) -> Field { + let hash_values = [left_child_hash,right_child_hash]; + return mimc::mimc7_hash(hash_values,0); +} + + +/// Computers the xor of two bits +fn calculate_xor(bit1: Bool, bit2: Bool) -> Bool{ + return ((bit1 && !bit2) || (!bit1 && bit2)); +} +/// computes the value of boolean +fn bool_to_field(bool: Bool) -> Field { + let zero = 0; + let one = 1; + let val = if bool { one } else { zero }; + return val; +} + +/// switches left and right inputs based on the insert sel bool +/// if self bool is true we switch l and r else same +fn switcher(sel: Bool,left: Field, right: Field) -> [Field;2] { + let mut outputs = [0;2]; + outputs[0] = if sel { right } else { left }; + outputs[1] = if sel { left } else { right }; + return outputs; +} + +/// This function get the level where the key will be inserted. level 0 is the root , 1 is next and so on.. +/// To find this level all the child level must have a sibling of 0 and +/// the parent level has a sibling != 0. The root is assumed to have a parent level with a sibling != 0 +fn level_inserted(siblings: [Field;LEN]) -> [Bool;LEN] { + let mut level_inserted = [false;LEN]; + let mut done = [false; LEN]; + + + level_inserted[(LEN - 1)] = !(siblings[LEN- 2] == 0); + done[LEN - 2] = level_inserted[LEN - 1]; + + for idx in 1..(LEN-1) { + let is_sibling_zero = siblings[(LEN - 2) - idx] == 0; + level_inserted[(LEN - 1) - idx] = (!done[(LEN - 1) - idx]) && (!is_sibling_zero); + done[(LEN-2) - idx] = level_inserted[(LEN - 1) - idx] || done[(LEN - 1) - idx]; + } + level_inserted[0] = !done[0]; + return level_inserted; +} + + +/// Only 1 state boolean should be true at a time as a level can only have 1 state boolean +fn assert_state_valid(state: [Bool;6]) { + let mut sum = 0; + for idx in 0..6 { + sum = sum + bool_to_field(state[idx]); + } + assert_eq(sum , 1); +} + +/// State transitions is as follows +/// `top` : The insert level for the key which we are replacing (in our path )is not reached i.e bits are the same for old key and new key +/// `upd` : The insert level is reached and we are performing an update operation on the tree +/// `old_zero` : The insert level is reached and the old key is zero +/// `btn` : The insert level is reached but the xor of old key and new key is same i.e they are on the same +/// `new` : The insert level is reached the xor of old key and new key is different therefore we add the new leaf and old leaf here +/// `na` : All the change to the tree has happend and now we just don't care about the states +/// calculates the next state for the levels when computing new roots +/// # Parameters +/// `prev_level_state`: state of the previous level +/// `xor`: the xor value for the old key and new key if same we just add a sibling +/// `is_old_zero`: indicates whethere we are inserting in an zero leaf +/// `level_inserted`: indicates that is this the level inserted +/// `is_update`: whether the operation is an update operation +/// # Returns +/// `[Bool;6]`: the next state +fn next_state( + prev_level_state: [Bool;6], + xor: Bool, + is_old_zero: Bool, + level_inserted: Bool, + is_update: Bool, +) -> [Bool;6] { + let mut new_states = [false;6]; + new_states[st_na] = prev_level_state[st_na] || ((!prev_level_state[st_top]) && (!prev_level_state[st_btn])); + new_states[st_top] = prev_level_state[st_top] && (!level_inserted); + new_states[st_upd] = prev_level_state[st_top] && (level_inserted && is_update); + new_states[st_old_zero] = prev_level_state[st_top] && (is_old_zero && level_inserted); + new_states[st_new] = ((prev_level_state[st_top] && ((level_inserted && xor) && !is_old_zero) ) ||(prev_level_state[st_btn] && xor)) && (!is_update); + new_states[st_btn] = ((prev_level_state[st_top] && ((level_inserted && !xor) && !is_old_zero) )||((prev_level_state[st_btn] && !xor))) && (!is_update); + assert_state_valid(new_states); + return new_states; +} + + + + +/// Computes the roots given old leafs and new leafs for different level based on the state +fn compute_roots( + old_leaf: Field, + new_leaf: Field, + old_child: Field, + new_child: Field, + sibling: Field, + new_lrbit: Bool, // left right bit of where the new leaf goes + state: [Bool;6], +) -> [Field;2] { + let is_btn = bool_to_field(state[st_btn]); + let is_top = bool_to_field(state[st_top]); + let is_new = bool_to_field(state[st_new]); + let is_upd = bool_to_field(state[st_upd]); + let is_old_zero = bool_to_field(state[st_old_zero]); + + + let is_top_or_btn = is_top + is_btn; + let is_top_or_btn_or_new = is_top_or_btn + is_new; + + let is_leaf = is_btn + (is_new + is_upd); + let old_chilren = switcher(new_lrbit , old_child , sibling); + let node_val = compute_internal_node_hash(old_chilren[0],old_chilren[1]); + + // if it is a leaf the root becomes the old_leaf and node_val in case of top i.e we are above the level + // where old key was inserted + let old_root = (is_leaf * old_leaf ) + (node_val * is_top); + + let new_leaf_cond = is_upd + is_old_zero; + + // did not use if else here as I need it to be zero also if the conditions are not met + // calculate left child and right child based on new_lrbit = 1 i.e true + let possible_left_child = (is_top_or_btn * new_child) + (is_new * new_leaf); + let possible_right_child = (is_top * sibling) + (is_new * old_leaf); + + + let new_children = switcher(new_lrbit , possible_left_child, possible_right_child); + let new_node_val = compute_internal_node_hash(new_children[0],new_children[1]); + let new_root = (new_node_val*is_top_or_btn_or_new) + (new_leaf * new_leaf_cond); + + return [ old_root, new_root]; +} + + +/// It computes a new root for the smt given an operations +/// as for every hash as the bottom level there is a node +/// # Parameteres +/// - `old_root` : the root before the operation +/// - `old_key` : the key we are replacing or which is in our path +/// - `old_value`: the value for the old key +/// - `new_key` : the new key in case of insert +/// - `new_value`: the new value for the key +/// - `siblings` : siblings for the path +/// - `isold_zero` : whether we are inserting in zero leaf +/// - `operation`: what operation we are doing +/// 0 - Insert +/// 1 - Update +/// 2 - Delete +/// # Returns +/// - `Field` The new computed root +fn compute_new_root( + old_root: Field, + old_key: Field, + old_value: Field, + new_key: Field, + new_value: Field, + siblings: [Field; LEN], + isold_zero : Bool, + operation: Field +) -> Field { + + // if it is a update then old and new key should be equal + let is_upd = bool_to_field(operation == 1); + assert_eq( (is_upd * old_key) , (is_upd * new_key) ); + + let old_leaf = compute_leaf_hash(old_key,old_value); + let new_leaf = compute_leaf_hash(new_key,new_value); + + let old_key_bits = bits::to_bits(LEN , old_key); + let new_key_bits = bits::to_bits(LEN , new_key); + + let mut level_inserted = level_inserted(siblings); + let mut states = [[false;6]; LEN]; + + // state above the root_state is assumed to be the top state + let mut abv_root_state = [true , false, false, false, false, false]; + states[0] = next_state( + abv_root_state, + false, + isold_zero, + level_inserted[0], + operation == 1 + ); + + // compute states of each level the reason for different loops is that in this loop we transition from top to other states + // as while computing the roots we need to know the child of the level below thus both the loops run in the opposite direction + for idx in 1..LEN { + let xor = calculate_xor(old_key_bits[idx] , new_key_bits[idx]); + states[idx] = next_state( + states[idx - 1], + xor, + isold_zero, + level_inserted[idx], + operation == 1 + ); + } + // last state should be na or new2 or old2 or upd i.e it should not be top or btn + let last_state_valid = (!states[LEN - 1][st_btn]) || (!states[LEN-1][st_top]); + assert(last_state_valid); + + let mut roots = [0,0]; + + // compute the roots for each level the child for the level is the last computed root + for idx in 0..LEN { + let new_roots = compute_roots( + old_leaf, + new_leaf, + roots[0], + roots[1], + siblings[(LEN - 1) - idx], + new_key_bits[(LEN - 1) -idx], + states[(LEN - 1) - idx] + ); + roots = new_roots; + } + + // if there is delete then switch the old root and new root + let is_delete = operation == 2; + roots = switcher(is_delete , roots[0] , roots[1]); + + // old root calcualted and old root given should be same + assert_eq(roots[0] , old_root); + return roots[1]; +} + + + +// ========= VERIFICAION ===== + +fn assert_vstate_valid(state:[Bool;5]) { + let mut sum = 0; + for ii in 0..5 { + sum = sum + bool_to_field(state[ii]); + } + assert_eq(sum , 1); +} + +/// State transitions for verification is as follows +/// `top` : The insert level for our key is not reached +/// `iz` : level inserted reach and the notFound key,val are zero and we have to verify non inclustion +/// `inz` : level inserted reach and the notFound key,val are non zero and we have to verify non inclustion +/// `inew`: level inserated reached and we have verify inclusion +/// `na` : All the verification has been done and now we just don't care about the states +/// when we generate (non-)membership proofs we have a state for each level this is helpful in computing the root +/// # Parameters +/// `prev_level_state`: state of the previous level +/// `is_nf_zero`: whether the not found key was zero i.e our key was inserted in a zero leaf or not +/// `is_insert_level`: whether the level inserted for our key has been reached +/// `inclusion_proof`: whether we are generating a membership proof or non-membership proof +/// # Returns +/// `[Bool;5]` : the next state +fn v_next_state( + prev_level_state: [Bool;5], + is_nf_zero: Bool, + is_insert_level: Bool, + inclusion_proof: Bool +) -> [Bool;5] { + let mut new_states = [false;5]; + new_states[v_st_na] = prev_level_state[v_st_na] || (!prev_level_state[v_st_top]); + new_states[v_st_top] = prev_level_state[v_st_top] && (!is_insert_level); + new_states[v_st_iz] = prev_level_state[v_st_top] && ((is_nf_zero && is_insert_level) && !inclusion_proof); + new_states[v_st_inz] = prev_level_state[v_st_top] && ((is_insert_level && (!is_nf_zero)) && !inclusion_proof); + new_states[v_st_inew] = prev_level_state[v_st_top] && (is_insert_level && inclusion_proof); + assert_vstate_valid(new_states); + return new_states; +} + + +/// compute the root for verification given the state and the siblings +/// # Parameters +/// `state`: the state for the level +/// `child`: the child of the level as we iterate from the bottom to the top we pass in the last computed root +/// `sibling`: the sibling for the key +/// `lrbit`: the bit for the key deciding whether it is left or right in the path +/// `not_found_leaf`: the leaf hash for the not found key +/// `leaf`: the leaf hash for the key we are creating (non-)membership proof for +fn compute_root( + state: [Bool;5], + child: Field, + sibling: Field, + lrbit: Bool, + not_found_leaf: Field, + leaf: Field +) -> Field { + let children = switcher(lrbit, child, sibling); + let node_hash = compute_internal_node_hash(children[0],children[1]); + + let is_top = bool_to_field(state[v_st_top]); + let is_inz = bool_to_field(state[v_st_inz]); + let is_inew = bool_to_field(state[v_st_inew]); + + let root = (is_top * node_hash) + ((is_inz * not_found_leaf) + (is_inew * leaf)); + return root; +} + + +/// Verifies the (non-)membership proof for the SMT +/// # Parameters +/// `root`: the root of the tree +/// `siblings`: the siblings for the path +/// `not_found_key`: when testing for non membership proofs the key and val found where the (key,val) should be +/// `not_found_val`: the value for the not found key +/// `inclusion_proof`: whether we are verifying a membership proof or non-membership proof +/// `is_nf_zero`: whether the not found key i.e the key val where the (key,val) should be is zero +/// `key`: the key for which we are verifying the proof +/// `value`: the value for the key +fn verify( + root: Field, + siblings: [Field;LEN], + not_found_key: Field, // + not_found_val: Field, + inclusion_proof: Bool, // true for memebership proofs and false for non-membership proofs + is_nf_zero: Bool, + key: Field, + value: Field +){ + + // for non inclusion proofs and is_nf_zero false `not_found_key` and `key` should not be equal + let valid_keys = (inclusion_proof || (not_found_key != key)) && !is_nf_zero; + assert(valid_keys); + + let not_found_leaf = compute_leaf_hash(not_found_key,not_found_val); + let leaf = compute_leaf_hash(key,value); + + let nfk_bits = bits::to_bits(LEN , not_found_key); + let key_bits = bits::to_bits(LEN , key); + + let insertion_level = level_inserted(siblings); + + let mut states = [[false;5];LEN]; + + // state above the root is assumed to be in top state + let abv_root_state = [true , false, false, false, false]; + states[0] = v_next_state( + abv_root_state, + is_nf_zero, + insertion_level[0], + inclusion_proof + ); + + for idx in 1..LEN { + states[idx] = v_next_state( + states[idx - 1], + is_nf_zero, + insertion_level[idx], + inclusion_proof + ); + } + let last_state_valid = (!states[LEN-1][v_st_top]); + let mut computed_root = 0; + + for idx in 0..LEN { + let new_root = compute_root( + states[(LEN - 1) - idx], + computed_root, + siblings[(LEN - 1) - idx], + key_bits[(LEN - 1) - idx], + not_found_leaf, + leaf + ); + computed_root = new_root; + } + + // whether the computed root is equal to the root of the tree + assert_eq(computed_root , root); +} \ No newline at end of file diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs index 278366d40..8d4145005 100644 --- a/src/tests/stdlib/mod.rs +++ b/src/tests/stdlib/mod.rs @@ -3,6 +3,8 @@ mod mimc; mod multiplexer; mod uints; +mod smt; + use std::{path::Path, str::FromStr}; use crate::{ diff --git a/src/tests/stdlib/smt/mod.rs b/src/tests/stdlib/smt/mod.rs new file mode 100644 index 000000000..5945f6665 --- /dev/null +++ b/src/tests/stdlib/smt/mod.rs @@ -0,0 +1,197 @@ +use super::test_stdlib; +use crate::error::Result; +use rstest::rstest; + +#[rstest] +#[case::empty_insert( + "111", + "222", + vec![], + vec![], + "0", + "0", + "0", + true, + "0", + "10729541595941744696255200734832925648647334864637393545770039840405438557214" +)] +#[case::insert( + "9700", + "8800", + vec![ + "11516567903461282088126784254894078034845453066905529710360444678685967109731", + "20659815156440161169257728848234717083009297735150715641331129637520746075208", + "9913888994783849052153109228696667088493732024315564088242255883696537997181", + "5846636767743912144636001040109151354569603553960606470003432371719273838746", + "11187195831226248797031600213284009205392316257240308867007036461570604014316", + "14862507023886888124421673098712328314603512225968198646585391849514658546416" + ], + vec![0 , 1, 2, 3, 4, 5], + "17581302140303159455912973258196037026284300302708949996108423583963947226858", + "3492", + "3168", + false, + "0", + "17713115569734927279694966589149598343072771273196461675493145305311861382883" +)] +#[case::delete( + "111", + "222", + vec!["9220985749551237028296517339018427057953245762011653459076210336571800515245"], + vec![1], + "21805692344774694518236557976212317500166230062431619316104692181032186605312", + "555", + "666", + false, + "2", + "3039938863220546817637150518308754073715763397170404924604005494776416658714" +)] +#[case::update( + "555", + "777", + vec!["9220985749551237028296517339018427057953245762011653459076210336571800515245"], + vec![1], + "3039938863220546817637150518308754073715763397170404924604005494776416658714", + "555", + "666", + false, + "1", // update operation + "18811865073273086230239721237564240209328819936273238864031238045766843861603" +)] +fn test_smt_cr( + #[case] key: &str, + #[case] val: &str, + #[case] non_zero_sibling: Vec<&str>, + #[case] non_zero_sibling_index: Vec, + #[case] old_root: &str, + #[case] old_key: &str, + #[case] old_val: &str, + #[case] is_old_zero: bool, + #[case] operation: &str, + #[case] expected_new_root: &str, +) -> Result<()> { + let mut siblings = vec!["0"; 254]; + + for (&index, sibling) in non_zero_sibling_index.iter().zip(non_zero_sibling) { + siblings[index] = sibling; + } + + let public_inputs = format!( + r#"{{"siblings": {:?}}}"#, + siblings + .iter() + .map(|ele| ele.to_string()) + .collect::>() + ); + + let mut values = vec!["0"; 7]; // 0 , 1 -> key ,val + values[0] = old_root; + values[1] = key; + values[2] = val; + values[3] = old_key; + values[4] = old_val; + values[5] = if is_old_zero { "0" } else { "1" }; + values[6] = operation; + + let private_inputs = format!( + r#"{{"values": {:?}}}"#, + values + .iter() + .map(|ele| ele.to_string()) + .collect::>() + ); + + test_stdlib( + "smt/smt_main.no", + None, + &public_inputs, + &private_inputs, + vec![expected_new_root], + )?; + Ok(()) +} + +#[rstest] +#[case::inclusion( + "333", + "444", + vec![ + "15403437905133579310679669358298285751036324375519574557984500284974195012647", + "16164523410687895121172017182256869209088533188202760284238496207325271948775" + ], + vec![0,1], + "12941802777540120349830076641367475813359582080712967896858975038182858131027", + "0", + "0", + false, + true, +)] +#[case::exclusion( + "1000", + "0", + vec![ + "5004112844904397918413167045606564570413835725979211272408079893204730422053", + "20870930364208425904173849538077157932823107157308028777399852634164408184090", + "18817399965850323578786675877025159015083291330173277928593283742904067184537", + "11475507857885337462985742557005542752995566264656500462775773988511382189430", + "20591657041708931641763347242286558192823550076918037187689654551559790721676", + "8391178249010813208860647414946215155510772994793073739371291818862143236795" + ], + vec![0,1,2,3,4,5], + "12941802777540120349830076641367475813359582080712967896858975038182858131027", + "1960", + "1760", + false, + false, +)] +fn test_smt_ie( + #[case] key: &str, + #[case] val: &str, + #[case] non_zero_sibling: Vec<&str>, + #[case] non_zero_sibling_index: Vec, + #[case] root: &str, + #[case] not_found_key: &str, + #[case] not_found_val: &str, + #[case] is_old_zero: bool, + #[case] inclusion_proof: bool, +) -> Result<()> { + let mut siblings = vec!["0"; 254]; + + for (&index, sibling) in non_zero_sibling_index.iter().zip(non_zero_sibling) { + siblings[index] = sibling; + } + + let public_inputs = format!( + r#"{{"siblings": {:?}}}"#, + siblings + .iter() + .map(|ele| ele.to_string()) + .collect::>() + ); + + let mut values = vec!["0"; 7]; // 0 , 1 -> key ,val + values[0] = root; + values[1] = not_found_key; + values[2] = not_found_val; + values[3] = key; + values[4] = val; + values[5] = if is_old_zero { "0" } else { "1" }; + values[6] = if inclusion_proof { "0" } else { "1" }; + + let private_inputs = format!( + r#"{{"values": {:?}}}"#, + values + .iter() + .map(|ele| ele.to_string()) + .collect::>() + ); + + test_stdlib( + "smt/smt_verify.no", + None, + &public_inputs, + &private_inputs, + vec![], + )?; + Ok(()) +} diff --git a/src/tests/stdlib/smt/smt_main.no b/src/tests/stdlib/smt/smt_main.no new file mode 100644 index 000000000..9265d4424 --- /dev/null +++ b/src/tests/stdlib/smt/smt_main.no @@ -0,0 +1,22 @@ +use std::smt; + +fn main(pub siblings: [Field;254], values: [Field;7]) -> Field { + let old_root = values[0]; + let old_key = values[3]; + let old_val = values[4]; + let key = values[1]; + let val = values[2]; + let operation = values[6]; // true false -> insert + let is_old_zero = values[5] == 0; + let new_root = smt::compute_new_root( + old_root, + old_key, + old_val, + key, + val, + siblings, + is_old_zero, + operation + ); + return new_root; +} \ No newline at end of file diff --git a/src/tests/stdlib/smt/smt_verify.no b/src/tests/stdlib/smt/smt_verify.no new file mode 100644 index 000000000..bf7fa5eb0 --- /dev/null +++ b/src/tests/stdlib/smt/smt_verify.no @@ -0,0 +1,21 @@ +use std::smt; + +fn main(pub siblings: [Field;254], values: [Field;7]) { + let root = values[0]; + let not_found_key = values[1]; + let not_found_val = values[2]; + let key = values[3]; + let value = values[4]; + let is_old_zero = values[5] == 0; + let inclusion_proof = values[6] == 0; + smt::verify( + root, + siblings, + not_found_key, // when testing for non membership proofs the key and val found where the (key,val) should be + not_found_val, + inclusion_proof, // true for memebership proofs and false for non-membership proofs + is_old_zero, + key, + value + ); +} \ No newline at end of file