Skip to content

Commit

Permalink
refactor: simple code cleanups in mst implementation (#269)
Browse files Browse the repository at this point in the history
* refactor: simple code cleanups in mst implementation

* split ci tests workflow in smaller jobs

* fix tests

* update parallel tests + run tests on aws only for summa team

* merge github workflows
  • Loading branch information
teddav authored Jul 11, 2024
1 parent 5237346 commit 23587ad
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 95 deletions.
58 changes: 56 additions & 2 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ env:

jobs:
wakeup:
if: github.event.pull_request.head.repo.full_name == 'summa-dev/summa-solvency'
runs-on: ubuntu-latest
permissions:
id-token: write
Expand All @@ -31,9 +32,10 @@ jobs:
aws-region: us-west-2

- name: Wakeup runner
run: .github/scripts/wakeup.sh
run: .github/scripts/wakeup.sh

build:
if: github.event.pull_request.head.repo.full_name == 'summa-dev/summa-solvency'
runs-on: [summa-solvency-runner]
needs: [wakeup]

Expand Down Expand Up @@ -71,4 +73,56 @@ jobs:
run: |
cd backend
cargo run --release --example summa_solvency_flow
test-zk-prover:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Test Zk Prover
run: |
cd zk_prover
cargo test --release --features dev-graph -- --nocapture
test-zk-prover-examples:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install solc
run: (hash svm 2>/dev/null || cargo install --version 0.2.23 svm-rs) && svm install 0.8.20 && solc --version
- name: Test Zk Prover examples
run: |
cd zk_prover
cargo run --release --example gen_inclusion_verifier
cargo run --release --example gen_commitment
cargo run --release --example gen_inclusion_proof
test-zk-prover-examples-nova:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Test Zk Prover examples
run: |
cd zk_prover
cargo run --release --example nova_incremental_verifier
test-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
- name: Test backend
run: |
cd backend
cargo test --release -- --nocapture
test-backend-examples:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
- name: Test backend example
run: |
cd backend
cargo run --release --example summa_solvency_flow
2 changes: 1 addition & 1 deletion backend/src/apis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ where
.map(|balance| BigUint::from_str_radix(balance, 10).unwrap())
.collect();

let entry: Entry<N_CURRENCIES> = Entry::new(username, balances.try_into().unwrap()).unwrap();
let entry: Entry<N_CURRENCIES> = Entry::new(username, balances.try_into().unwrap());

// Convert Fp to U256
let hash_str = format!("{:?}", entry.compute_leaf().hash);
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ where
// Assign the entry username to the witness
let username = self.assign_value_to_witness(
layouter.namespace(|| "assign entry username"),
big_uint_to_fp(self.entry.username_as_big_uint()),
big_uint_to_fp(&self.entry.username_as_big_uint()),
"entry username",
config.advices[0],
)?;
Expand Down
3 changes: 1 addition & 2 deletions zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ mod test {
let invalid_leaf_balances = [1000.to_biguint().unwrap(), 1000.to_biguint().unwrap()];

// invalidate user entry
let invalid_entry =
Entry::new(circuit.entry.username().to_string(), invalid_leaf_balances).unwrap();
let invalid_entry = Entry::new(circuit.entry.username().to_string(), invalid_leaf_balances);

circuit.entry = invalid_entry;

Expand Down
6 changes: 3 additions & 3 deletions zk_prover/src/merkle_sum_tree/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ pub struct Entry<const N_CURRENCIES: usize> {
}

impl<const N_CURRENCIES: usize> Entry<N_CURRENCIES> {
pub fn new(username: String, balances: [BigUint; N_CURRENCIES]) -> Result<Self, &'static str> {
pub fn new(username: String, balances: [BigUint; N_CURRENCIES]) -> Self {
// Security Assumptions:
// Using `keccak256` for `hashed_username` ensures high collision resistance,
// appropriate for the assumed userbase of $2^{30}$.
// The `hashed_username` utilizes the full 256 bits produced by `keccak256`,
// but is adjusted to the field size through the Poseidon hash function's modulo operation.
let hashed_username: BigUint = BigUint::from_bytes_be(&keccak256(username.as_bytes()));
Ok(Entry {
Entry {
hashed_username,
balances,
username,
})
}
}

/// Returns a zero entry where the username is 0 and the balances are all 0
Expand Down
5 changes: 1 addition & 4 deletions zk_prover/src/merkle_sum_tree/mst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES
{
let depth = (entries.len() as f64).log2().ceil() as usize;

let mut nodes = vec![];

// Pad the entries with empty entries to make the number of entries equal to 2^depth
if entries.len() < 2usize.pow(depth as u32) {
entries.extend(vec![
Expand All @@ -123,7 +121,7 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES

let leaves = build_leaves_from_entries(&entries);

let root = build_merkle_tree_from_leaves(&leaves, depth, &mut nodes)?;
let (root, nodes) = build_merkle_tree_from_leaves(&leaves, depth)?;

Ok(MerkleSumTree {
root,
Expand Down Expand Up @@ -202,7 +200,6 @@ impl<const N_CURRENCIES: usize, const N_BYTES: usize> MerkleSumTree<N_CURRENCIES
}

let root = self.nodes[self.depth][0].clone();

Ok(root)
}

Expand Down
45 changes: 8 additions & 37 deletions zk_prover/src/merkle_sum_tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ impl<const N_CURRENCIES: usize> Node<N_CURRENCIES> {
where
[usize; N_CURRENCIES + 1]: Sized,
{
let hash =
poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 1 }>, 2, 1>::init()
.hash(preimage.clone());
Node {
hash: Self::poseidon_hash_leaf(preimage[0], preimage[1..].try_into().unwrap()),
hash,
balances: preimage[1..].try_into().unwrap(),
}
}
Expand All @@ -71,44 +74,12 @@ impl<const N_CURRENCIES: usize> Node<N_CURRENCIES> {
where
[usize; N_CURRENCIES + 2]: Sized,
{
let hash =
poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 2 }>, 2, 1>::init()
.hash(preimage.clone());
Node {
hash: Self::poseidon_hash_middle(
preimage[0..N_CURRENCIES].try_into().unwrap(),
preimage[N_CURRENCIES],
preimage[N_CURRENCIES + 1],
),
hash,
balances: preimage[0..N_CURRENCIES].try_into().unwrap(),
}
}

fn poseidon_hash_middle(
balances_sum: [Fp; N_CURRENCIES],
hash_child_left: Fp,
hash_child_right: Fp,
) -> Fp
where
[usize; N_CURRENCIES + 2]: Sized,
{
let mut hash_inputs: [Fp; N_CURRENCIES + 2] = [Fp::zero(); N_CURRENCIES + 2];

hash_inputs[0..N_CURRENCIES].copy_from_slice(&balances_sum);
hash_inputs[N_CURRENCIES] = hash_child_left;
hash_inputs[N_CURRENCIES + 1] = hash_child_right;

poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 2 }>, 2, 1>::init()
.hash(hash_inputs)
}

fn poseidon_hash_leaf(username: Fp, balances: [Fp; N_CURRENCIES]) -> Fp
where
[usize; N_CURRENCIES + 1]: Sized,
{
let mut hash_inputs: [Fp; N_CURRENCIES + 1] = [Fp::zero(); N_CURRENCIES + 1];

hash_inputs[0] = username;
hash_inputs[1..N_CURRENCIES + 1].copy_from_slice(&balances);

poseidon::Hash::<Fp, PoseidonSpec, ConstantLength<{ N_CURRENCIES + 1 }>, 2, 1>::init()
.hash(hash_inputs)
}
}
3 changes: 1 addition & 2 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ mod test {
let invalid_entry = Entry::new(
"AtwIxZHo".to_string(),
[35479.to_biguint().unwrap(), 35479.to_biguint().unwrap()],
)
.unwrap();
);
let invalid_entry = invalid_entry;
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.entry = invalid_entry;
Expand Down
17 changes: 9 additions & 8 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait Tree<const N_CURRENCIES: usize> {
let mut preimage = [Fp::zero(); N_CURRENCIES + 1];

// Add username to preimage
preimage[0] = big_uint_to_fp(entry.username_as_big_uint());
preimage[0] = big_uint_to_fp(&entry.username_as_big_uint());

// Add balances to preimage
for (i, balance) in preimage.iter_mut().enumerate().skip(1).take(N_CURRENCIES) {
Expand All @@ -97,6 +97,7 @@ pub trait Tree<const N_CURRENCIES: usize> {
if index >= nodes[0].len() {
return Err(Box::from("Index out of bounds"));
}
assert_eq!(nodes[0].len(), 2usize.pow(depth as u32));

let mut sibling_middle_node_hash_preimages = Vec::with_capacity(depth - 1);

Expand All @@ -111,7 +112,9 @@ pub trait Tree<const N_CURRENCIES: usize> {
let position = current_index % 2;
let sibling_index = current_index - position + (1 - position);

if sibling_index < nodes[level].len() && level != 0 {
// we asserted that the leaves vec length is a power of 2
// so the index shouldn't overflow the level's length
if level > 0 {
// Fetch hash preimage for sibling middle nodes
let sibling_node_preimage =
self.get_middle_node_hash_preimage(level, sibling_index)?;
Expand Down Expand Up @@ -152,14 +155,13 @@ pub trait Tree<const N_CURRENCIES: usize> {
if proof.path_indices[0] == 0.into() {
hash_preimage[N_CURRENCIES] = node.hash;
hash_preimage[N_CURRENCIES + 1] = sibling_leaf_node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
} else {
hash_preimage[N_CURRENCIES] = sibling_leaf_node.hash;
hash_preimage[N_CURRENCIES + 1] = node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
}
node = Node::middle_node_from_preimage(&hash_preimage);

for i in 1..proof.path_indices.len() {
for (i, path_index) in proof.path_indices.iter().enumerate().skip(1) {
let sibling_node = Node::<N_CURRENCIES>::middle_node_from_preimage(
&proof.sibling_middle_node_hash_preimages[i - 1],
);
Expand All @@ -169,15 +171,14 @@ pub trait Tree<const N_CURRENCIES: usize> {
*balance = node.balances[i] + sibling_node.balances[i];
}

if proof.path_indices[i] == 0.into() {
if *path_index == 0.into() {
hash_preimage[N_CURRENCIES] = node.hash;
hash_preimage[N_CURRENCIES + 1] = sibling_node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
} else {
hash_preimage[N_CURRENCIES] = sibling_node.hash;
hash_preimage[N_CURRENCIES + 1] = node.hash;
node = Node::middle_node_from_preimage(&hash_preimage);
}
node = Node::middle_node_from_preimage(&hash_preimage);
}

proof.root.hash == node.hash && proof.root.balances == node.balances
Expand Down
45 changes: 11 additions & 34 deletions zk_prover/src/merkle_sum_tree/utils/build_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,25 @@ use rayon::prelude::*;
pub fn build_merkle_tree_from_leaves<const N_CURRENCIES: usize>(
leaves: &[Node<N_CURRENCIES>],
depth: usize,
nodes: &mut Vec<Vec<Node<N_CURRENCIES>>>,
) -> Result<Node<N_CURRENCIES>, Box<dyn std::error::Error>>
) -> Result<(Node<N_CURRENCIES>, Vec<Vec<Node<N_CURRENCIES>>>), Box<dyn std::error::Error>>
where
[usize; N_CURRENCIES + 1]: Sized,
[usize; N_CURRENCIES + 2]: Sized,
{
let n = leaves.len();

let mut tree: Vec<Vec<Node<N_CURRENCIES>>> = Vec::with_capacity(depth + 1);

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_CURRENCIES]
};
n
]);

for _ in 1..=depth {
let previous_level = tree.last().unwrap();
let nodes_in_level = (previous_level.len() + 1) / 2;
// the size of a leaf layer must be a power of 2
// if not, the `leaves` Vec should be completed with "zero entries" until a power of 2
assert_eq!(leaves.len(), 2usize.pow(depth as u32));

tree.push(vec![
Node {
hash: Fp::from(0),
balances: [Fp::from(0); N_CURRENCIES]
};
nodes_in_level
]);
}

for (index, leaf) in leaves.iter().enumerate() {
tree[0][index] = leaf.clone();
}
tree.push(leaves.to_vec());

for level in 1..=depth {
build_middle_level(level, &mut tree)
}

let root = tree[depth][0].clone();
*nodes = tree;
Ok(root)
Ok((root, tree))
}

pub fn build_leaves_from_entries<const N_CURRENCIES: usize>(
Expand Down Expand Up @@ -74,8 +51,10 @@ where
leaves
}

fn build_middle_level<const N_CURRENCIES: usize>(level: usize, tree: &mut [Vec<Node<N_CURRENCIES>>])
where
fn build_middle_level<const N_CURRENCIES: usize>(
level: usize,
tree: &mut Vec<Vec<Node<N_CURRENCIES>>>,
) where
[usize; N_CURRENCIES + 2]: Sized,
{
let results: Vec<Node<N_CURRENCIES>> = (0..tree[level - 1].len())
Expand All @@ -95,7 +74,5 @@ where
})
.collect();

for (index, new_node) in results.into_iter().enumerate() {
tree[level][index] = new_node;
}
tree.push(results);
}
3 changes: 2 additions & 1 deletion zk_prover/src/merkle_sum_tree/utils/csv_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ pub fn parse_csv_to_entries<P: AsRef<Path>, const N_CURRENCIES: usize, const N_B
balances_big_int.push(balance);
}

let entry = Entry::new(username, balances_big_int.try_into().unwrap())?;
let entry = Entry::new(username, balances_big_int.try_into().unwrap());

entries.push(entry);
}

Expand Down

0 comments on commit 23587ad

Please sign in to comment.