From 4412e100469199bb548f314aa2d03450963a877d Mon Sep 17 00:00:00 2001
From: "enrico.eth" <85900164+enricobottazzi@users.noreply.github.com>
Date: Mon, 13 Nov 2023 17:09:32 +0100
Subject: [PATCH] feat: add `aggregation_merkle_sum_tree` (#2)

---
 Cargo.lock                         |  15 +-
 Cargo.toml                         |   3 +-
 bin/mini_tree_server.rs            |  85 +++++-----
 src/aggregation_merkle_sum_tree.rs | 261 +++++++++++++++++++++++++++++
 src/data/entry_16_1.csv            |  17 ++
 src/data/entry_16_2.csv            |  17 ++
 src/data/entry_16_3.csv            |  17 ++
 src/lib.rs                         |   3 +
 src/main.rs                        |   3 -
 9 files changed, 368 insertions(+), 53 deletions(-)
 create mode 100644 src/aggregation_merkle_sum_tree.rs
 create mode 100644 src/data/entry_16_1.csv
 create mode 100644 src/data/entry_16_2.csv
 create mode 100644 src/data/entry_16_3.csv
 create mode 100644 src/lib.rs
 delete mode 100644 src/main.rs

diff --git a/Cargo.lock b/Cargo.lock
index 51be32a..f432263 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -762,9 +762,9 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
 
 [[package]]
 name = "crypto-bigint"
-version = "0.5.3"
+version = "0.5.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "740fe28e594155f10cfc383984cbefd529d7396050557148f79cb0f621204124"
+checksum = "28f85c3514d2a6e64160359b45a3918c3b4178bcbf4ae5d03ab2d02e521c479a"
 dependencies = [
  "generic-array",
  "rand_core 0.6.4",
@@ -3059,9 +3059,9 @@ dependencies = [
 
 [[package]]
 name = "proptest"
-version = "1.3.1"
+version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e"
+checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf"
 dependencies = [
  "bitflags 2.4.1",
  "lazy_static",
@@ -3069,7 +3069,7 @@ dependencies = [
  "rand 0.8.5",
  "rand_chacha",
  "rand_xorshift",
- "regex-syntax 0.7.5",
+ "regex-syntax 0.8.2",
  "unarray",
 ]
 
@@ -3933,6 +3933,7 @@ dependencies = [
  "const_env",
  "halo2_proofs",
  "num-bigint 0.4.4",
+ "rand 0.8.5",
  "serde",
  "serde_json",
  "summa-backend",
@@ -3942,7 +3943,7 @@
 [[package]]
 name = "summa-backend"
 version = "0.1.0"
-source = "git+https://github.com/summa-dev/summa-solvency?branch=v1-for-summa-aggregation#b5ef9f2876b0612c4552863c2b236bd21584b537"
+source = "git+https://github.com/summa-dev/summa-solvency?branch=v1-improvements-and-consolidation#8ab0b07587ced37d341266a73c187bb22e03560e"
 dependencies = [
  "base64 0.13.1",
  "bincode",
@@ -3962,7 +3963,7 @@ dependencies = [
 [[package]]
 name = "summa-solvency"
 version = "0.1.0"
-source = "git+https://github.com/summa-dev/summa-solvency?branch=v1-for-summa-aggregation#b5ef9f2876b0612c4552863c2b236bd21584b537"
+source = "git+https://github.com/summa-dev/summa-solvency?branch=v1-improvements-and-consolidation#8ab0b07587ced37d341266a73c187bb22e03560e"
 dependencies = [
  "ark-std",
  "csv",
diff --git a/Cargo.toml b/Cargo.toml
index 264f92f..7883a98 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,9 +11,10 @@ const_env = "0.1.2"
 num-bigint = "0.4.4"
 serde = { version = "1.0.192", features = ["derive"] }
 serde_json = "1.0.108"
-summa-backend = { git = "https://github.com/summa-dev/summa-solvency", branch = "v1-for-summa-aggregation", version = "0.1.0" }
+summa-backend = { git = "https://github.com/summa-dev/summa-solvency", branch = "v1-improvements-and-consolidation", version = "0.1.0" }
 halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2023_04_20"}
 tokio = { version = "1.34.0", features = ["full"] }
+rand = "0.8"
 
 [[bin]]
 name = "mini-tree-server"
"mini-tree-server" diff --git a/bin/mini_tree_server.rs b/bin/mini_tree_server.rs index 8738eab..709e4d2 100644 --- a/bin/mini_tree_server.rs +++ b/bin/mini_tree_server.rs @@ -1,22 +1,15 @@ +use axum::{extract::Json, http::StatusCode, response::IntoResponse, routing::post, Router}; use const_env::from_env; -use axum::{ - extract::Json, - response::IntoResponse, - routing::post, - Router, - http::StatusCode, -}; +use num_bigint::BigUint; use std::net::SocketAddr; -use num_bigint::{BigUint}; +use summa_backend::merkle_sum_tree::{Entry, MerkleSumTree, Node, Tree}; -use summa_backend::{MerkleSumTree, Entry, Node, Tree}; +use serde::{Deserialize, Serialize}; -use serde::{Serialize, Deserialize}; - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonNode { - pub hash: String, - pub balances: Vec, + pub hash: String, + pub balances: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -42,11 +35,10 @@ const N_BYTES: usize = 14; #[tokio::main] async fn main() { // Define the app with a route - let app = Router::new() - .route("/", post(create_mst)); + let app = Router::new().route("/", post(create_mst)); // Define the address to serve on - let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); // TODO: assign ports from env variable + let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); // TODO: assign ports from env variable // Start the server axum::Server::bind(&addr) @@ -56,40 +48,49 @@ async fn main() { } fn convert_node_to_json(node: &Node) -> JsonNode { - JsonNode { - hash: format!("{:?}", node.hash), - balances: node.balances.iter().map(|b| format!("{:?}", b)).collect(), - } + JsonNode { + hash: format!("{:?}", node.hash), + balances: node.balances.iter().map(|b| format!("{:?}", b)).collect(), + } } -async fn create_mst(Json(json_entries): Json>) -> Result)> { +async fn create_mst( + Json(json_entries): Json>, +) -> Result)> { // Convert `JsonEntry` -> `Entry` - let entries = json_entries.iter().map(|entry| { - let mut balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32)); - entry.balances.iter().enumerate().for_each(|(i, balance)| { - balances[i] = balance.parse::().unwrap(); - }); - Entry::new(entry.username.clone(), balances).unwrap() - }).collect::>>(); + let entries = json_entries + .iter() + .map(|entry| { + let mut balances: [BigUint; N_ASSETS] = std::array::from_fn(|_| BigUint::from(0u32)); + entry.balances.iter().enumerate().for_each(|(i, balance)| { + balances[i] = balance.parse::().unwrap(); + }); + Entry::new(entry.username.clone(), balances).unwrap() + }) + .collect::>>(); // Create `MerkleSumTree` from `parsed_entries` let tree = MerkleSumTree::::from_entries(entries, false).unwrap(); // Convert `MerkleSumTree` to `JsonMerkleSumTree` let json_tree = JsonMerkleSumTree { - root: convert_node_to_json(&tree.root()), - nodes: tree.nodes().iter().map(|layer| { - layer.iter().map(convert_node_to_json).collect() - }).collect(), - depth: tree.depth().clone(), - entries: tree.entries().iter().map(|entry| { - JsonEntry { - balances: entry.balances().iter().map(|b| b.to_string()).collect(), - username: entry.username().to_string(), - } - }).collect(), - is_sorted: false, // TODO: assign from request data - }; - + root: convert_node_to_json(&tree.root()), + nodes: tree + .nodes() + .iter() + .map(|layer| layer.iter().map(convert_node_to_json).collect()) + .collect(), + depth: tree.depth().clone(), + entries: tree + .entries() + .iter() + .map(|entry| JsonEntry { + balances: entry.balances().iter().map(|b| b.to_string()).collect(), + username: 
+            })
+            .collect(),
+        is_sorted: false, // TODO: assign from request data
+    };
+
     Ok((StatusCode::OK, Json(json_tree)))
 }
diff --git a/src/aggregation_merkle_sum_tree.rs b/src/aggregation_merkle_sum_tree.rs
new file mode 100644
index 0000000..5eacc4c
--- /dev/null
+++ b/src/aggregation_merkle_sum_tree.rs
@@ -0,0 +1,261 @@
+use halo2_proofs::halo2curves::bn256::Fr as Fp;
+use num_bigint::BigUint;
+use summa_backend::merkle_sum_tree::utils::{build_merkle_tree_from_leaves, fp_to_big_uint};
+use summa_backend::merkle_sum_tree::{Entry, MerkleProof, MerkleSumTree, Node, Tree};
+
+/// Aggregation Merkle Sum Tree Data Structure.
+///
+/// Starting from a set of "mini" Merkle Sum Trees of equal depth, `N_ASSETS` and `N_BYTES`, the Aggregation Merkle Sum Tree inherits the properties of a Merkle Sum Tree and adds the following:
+/// * Each leaf of the Aggregation Merkle Sum Tree is the root of a "mini" Merkle Sum Tree, made of `hash` and `balances`
+///
+/// # Type Parameters
+///
+/// * `N_ASSETS`: The number of assets for each user account
+/// * `N_BYTES`: Range in which each node balance should lie
+#[derive(Debug, Clone)]
+pub struct AggregationMerkleSumTree<const N_ASSETS: usize, const N_BYTES: usize> {
+    root: Node<N_ASSETS>,
+    nodes: Vec<Vec<Node<N_ASSETS>>>,
+    depth: usize,
+    mini_trees: Vec<MerkleSumTree<N_ASSETS, N_BYTES>>,
+}
+
+impl<const N_ASSETS: usize, const N_BYTES: usize> Tree<N_ASSETS, N_BYTES>
+    for AggregationMerkleSumTree<N_ASSETS, N_BYTES>
+{
+    fn root(&self) -> &Node<N_ASSETS> {
+        &self.root
+    }
+
+    fn depth(&self) -> &usize {
+        &self.depth
+    }
+
+    fn leaves(&self) -> &[Node<N_ASSETS>] {
+        &self.nodes[0]
+    }
+
+    fn nodes(&self) -> &[Vec<Node<N_ASSETS>>] {
+        &self.nodes
+    }
+
+    fn get_entry(&self, user_index: usize) -> &Entry<N_ASSETS> {
+        let (mini_tree_index, entry_index) = self.get_entry_location(user_index);
+
+        // Retrieve the mini tree
+        let mini_tree = &self.mini_trees[mini_tree_index];
+
+        // Retrieve the entry within the mini tree
+        mini_tree.get_entry(entry_index)
+    }
+
+    fn generate_proof(&self, index: usize) -> Result<MerkleProof<N_ASSETS, N_BYTES>, &'static str> {
+        let (mini_tree_index, entry_index) = self.get_entry_location(index);
+
+        // Retrieve the mini tree
+        let mini_tree = &self.mini_trees[mini_tree_index];
+
+        // Build the partial proof, namely from the leaf to the root of the mini tree
+        let mut partial_proof = mini_tree.generate_proof(entry_index)?;
+
+        // Build the rest of the proof (top_proof), namely from the root of the mini tree to the root of the aggregation tree
+        let mut current_index = mini_tree_index;
+
+        let mut sibling_hashes = vec![Fp::from(0); self.depth];
+        let mut sibling_sums = vec![[Fp::from(0); N_ASSETS]; self.depth];
+        let mut path_indices = vec![Fp::from(0); self.depth];
+
+        for level in 0..self.depth {
+            let position = current_index % 2;
+            let level_start_index = current_index - position;
+            let level_end_index = level_start_index + 2;
+
+            path_indices[level] = Fp::from(position as u64);
+
+            for i in level_start_index..level_end_index {
+                if i != current_index {
+                    sibling_hashes[level] = self.nodes[level][i].hash;
+                    sibling_sums[level] = self.nodes[level][i].balances;
+                }
+            }
+            current_index /= 2;
+        }
+
+        // append the top_proof to the partial_proof
+        partial_proof.sibling_hashes.extend(sibling_hashes);
+        partial_proof.sibling_sums.extend(sibling_sums);
+        partial_proof.path_indices.extend(path_indices);
+
+        // replace the root of the partial proof with the root of the aggregation tree
+        partial_proof.root = self.root.clone();
+
+        Ok(partial_proof)
+    }
+}
+
+impl<const N_ASSETS: usize, const N_BYTES: usize> AggregationMerkleSumTree<N_ASSETS, N_BYTES> {
+    /// Builds an AggregationMerkleSumTree from a set of mini MerkleSumTrees.
+    /// The leaves of the AggregationMerkleSumTree are the roots of the mini MerkleSumTrees.
+    pub fn new(
+        mini_trees: Vec<MerkleSumTree<N_ASSETS, N_BYTES>>,
+    ) -> Result<Self, Box<dyn std::error::Error>>
+    where
+        [usize; N_ASSETS + 1]: Sized,
+        [usize; 2 * (1 + N_ASSETS)]: Sized,
+    {
+        // assert that all mini trees have the same depth
+        let depth = mini_trees[0].depth();
+        assert!(mini_trees.iter().all(|x| x.depth() == depth));
+
+        Self::build_tree(mini_trees)
+    }
+
+    fn build_tree(
+        mini_trees: Vec<MerkleSumTree<N_ASSETS, N_BYTES>>,
+    ) -> Result<AggregationMerkleSumTree<N_ASSETS, N_BYTES>, Box<dyn std::error::Error>>
+    where
+        [usize; N_ASSETS + 1]: Sized,
+        [usize; 2 * (1 + N_ASSETS)]: Sized,
+    {
+        // extract all the roots of the mini trees
+        let roots = mini_trees
+            .iter()
+            .map(|x| x.root().clone())
+            .collect::<Vec<Node<N_ASSETS>>>();
+
+        let depth = (roots.len() as f64).log2().ceil() as usize;
+
+        // Calculate the accumulated balances for each asset
+        let mut balances_acc: Vec<Fp> = vec![Fp::from(0); N_ASSETS];
+
+        for root in &roots {
+            for (i, balance) in root.balances.iter().enumerate() {
+                balances_acc[i] += *balance;
+            }
+        }
+
+        // Return an error if any accumulated balance is not in the range [0, 2 ^ (8 * N_BYTES))
+        for balance in &balances_acc {
+            // transform the balance to a BigUint
+            let balance_big_uint = fp_to_big_uint(*balance);
+
+            if balance_big_uint >= BigUint::from(2_usize).pow(8 * N_BYTES as u32) {
+                return Err(
+                    "Accumulated balance is not in the expected range, proof generation will fail!"
+                        .into(),
+                );
+            }
+        }
+
+        let mut nodes = vec![];
+        let root = build_merkle_tree_from_leaves(&roots, depth, &mut nodes)?;
+
+        Ok(AggregationMerkleSumTree {
+            root,
+            nodes,
+            depth,
+            mini_trees,
+        })
+    }
+
+    pub fn mini_tree(&self, tree_index: usize) -> &MerkleSumTree<N_ASSETS, N_BYTES> {
+        &self.mini_trees[tree_index]
+    }
+
+    /// Starting from a user_index, returns the index of the mini tree in which the entry is located and the index of the entry within that mini tree.
+    fn get_entry_location(&self, user_index: usize) -> (usize, usize) {
+        let entries_per_mini_tree = 1 << self.mini_trees[0].depth();
+
+        // Calculate which mini tree the entry is in
+        let mini_tree_index = user_index / entries_per_mini_tree;
+
+        // Calculate the index within the mini tree
+        let entry_index = user_index % entries_per_mini_tree;
+
+        (mini_tree_index, entry_index)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use num_bigint::ToBigUint;
+    use summa_backend::merkle_sum_tree::{MerkleSumTree, Tree};
+
+    use crate::aggregation_merkle_sum_tree::AggregationMerkleSumTree;
+
+    const N_ASSETS: usize = 2;
+    const N_BYTES: usize = 8;
+
+    #[test]
+    fn test_aggregation_mst() {
+        // create new mini merkle sum trees
+        let mini_tree_1 =
+            MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_1.csv").unwrap();
+
+        let mini_tree_2 =
+            MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_2.csv").unwrap();
+
+        let aggregation_mst = AggregationMerkleSumTree::<N_ASSETS, N_BYTES>::new(vec![
+            mini_tree_1.clone(),
+            mini_tree_2.clone(),
+        ])
+        .unwrap();
+
+        // get root
+        let root = aggregation_mst.root();
+
+        // expect root hash to be different than 0
+        assert!(root.hash != 0.into());
+        // expect balance to match the sum of all entries
+        assert!(root.balances == [(556862 * 2).into(), (556862 * 2).into()]);
+
+        // the aggregation tree's own depth only covers the levels above the mini tree roots: log2(2 mini trees) = 1
+        let depth = aggregation_mst.depth();
+
+        assert!(*depth == 1);
+
+        let index = rand::random::<usize>() % 32;
+
+        // the entry fetched from the aggregation tree should be the same as the entry fetched from the corresponding mini tree
+        let entry = aggregation_mst.get_entry(index);
+
+        if index < 16 {
+            assert!(entry.username() == mini_tree_1.get_entry(index).username());
+            assert!(entry.balances() == mini_tree_1.get_entry(index).balances());
+        } else {
+            assert!(entry.username() == mini_tree_2.get_entry(index - 16).username());
+            assert!(entry.balances() == mini_tree_2.get_entry(index - 16).balances());
+        }
+
+        // Generate proof for the entry
+        let proof = aggregation_mst.generate_proof(index).unwrap();
+
+        // verify proof
+        assert!(aggregation_mst.verify_proof(&proof));
+    }
+
+    #[test]
+    fn test_aggregation_mst_overflow() {
+        // create new mini merkle sum trees. The accumulated balance of each individual mini tree is in the expected range;
+        // note that the accumulated balance of the tree generated from entry_16_3.csv is only 1 unit below the upper bound of the range
+        let merkle_sum_tree_1 =
+            MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_1.csv").unwrap();
+
+        let merkle_sum_tree_2 =
+            MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_3.csv").unwrap();
+
+        // When creating the aggregation merkle sum tree, the accumulated balance of the two mini trees is not in the expected range, so an error is thrown
+        let result = AggregationMerkleSumTree::<N_ASSETS, N_BYTES>::new(vec![
+            merkle_sum_tree_1,
+            merkle_sum_tree_2.clone(),
+        ]);
+
+        assert!(result.is_err());
+
+        if let Err(e) = result {
+            assert_eq!(
+                e.to_string(),
+                "Accumulated balance is not in the expected range, proof generation will fail!"
+            );
+        }
+    }
+}
diff --git a/src/data/entry_16_1.csv b/src/data/entry_16_1.csv
new file mode 100644
index 0000000..228b1db
--- /dev/null
+++ b/src/data/entry_16_1.csv
@@ -0,0 +1,17 @@
+username;balances
+dxGaEAii;11888,41163
+MBlfbBGI;67823,18651
+lAhWlEWZ;18651,2087
+nuZweYtO;22073,55683
+gbdSwiuY;34897,83296
+RZNneNuP;83296,16881
+YsscHXkp;31699,35479
+RkLzkDun;2087,79731
+HlQlnEYI;30605,11888
+RqkZOFYe;16881,14874
+NjCSRAfD;41163,67823
+pHniJMQY;14874,22073
+dOGIMzKR;10032,10032
+HfMDmNLp;55683,34897
+xPLKzCBl;79731,30605
+AtwIxZHo;35479,31699
diff --git a/src/data/entry_16_2.csv b/src/data/entry_16_2.csv
new file mode 100644
index 0000000..b2183a2
--- /dev/null
+++ b/src/data/entry_16_2.csv
@@ -0,0 +1,17 @@
+username;balances
+aaGaEAaa;11888,41163
+bblfbBGI;67823,18651
+cchWlEWZ;18651,2087
+ddZweYtO;22073,55683
+eedSwiuY;34897,83296
+ffNneNuP;83296,16881
+ggscHXkp;31699,35479
+hhLzkDun;2087,79731
+iiQlnEYI;30605,11888
+llkZOFYe;16881,14874
+mmCSRAfD;41163,67823
+nnniJMQY;14874,22073
+ooGIMzKR;10032,10032
+ppMDmNLp;55683,34897
+qqLKzCBl;79731,30605
+rrwIxZHo;35479,31699
\ No newline at end of file
diff --git a/src/data/entry_16_3.csv b/src/data/entry_16_3.csv
new file mode 100644
index 0000000..49dc935
--- /dev/null
+++ b/src/data/entry_16_3.csv
@@ -0,0 +1,17 @@
+username;balances
+dxGaEAii;18446744073709551615,0
+MBlfbBGI;0,18446744073709551615
+lAhWlEWZ;0,0
+nuZweYtO;0,0
+gbdSwiuY;0,0
+RZNneNuP;0,0
+YsscHXkp;0,0
+RkLzkDun;0,0
+HlQlnEYI;0,0
+RqkZOFYe;0,0
+NjCSRAfD;0,0
+pHniJMQY;0,0
+dOGIMzKR;0,0
+HfMDmNLp;0,0
+xPLKzCBl;0,0
+AtwIxZHo;0,0
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..cc00c82
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,3 @@
+#![feature(generic_const_exprs)]
+
+pub mod aggregation_merkle_sum_tree;
diff --git a/src/main.rs b/src/main.rs
deleted file mode 100644
index 9a42770..0000000
--- a/src/main.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-fn main() {
-    println!("Welcome, Orchestrator!");
-}
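
Usage sketch (reviewer note, not part of the diff): the snippet below wires the new pieces together the same way the tests in src/aggregation_merkle_sum_tree.rs do — two 16-entry mini trees are built from the bundled CSV files, aggregated, queried by a global user index, and an inclusion proof spanning both the mini tree levels and the aggregation levels is generated and verified. The crate path summa_aggregation::aggregation_merkle_sum_tree and the standalone main function are assumptions; the constants, file paths, and method calls are taken from the tests above, and the crate must be built on nightly because src/lib.rs enables generic_const_exprs.

    // Illustrative sketch only; assumes the library crate is importable as `summa_aggregation`.
    use summa_aggregation::aggregation_merkle_sum_tree::AggregationMerkleSumTree;
    use summa_backend::merkle_sum_tree::{MerkleSumTree, Tree};

    const N_ASSETS: usize = 2;
    const N_BYTES: usize = 8;

    fn main() {
        // Build two "mini" Merkle sum trees of equal depth from the CSV files added in this patch
        let mini_tree_1 = MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_1.csv").unwrap();
        let mini_tree_2 = MerkleSumTree::<N_ASSETS, N_BYTES>::new("src/data/entry_16_2.csv").unwrap();

        // Aggregate them: the mini tree roots become the leaves of the aggregation tree
        let aggregation_mst =
            AggregationMerkleSumTree::<N_ASSETS, N_BYTES>::new(vec![mini_tree_1, mini_tree_2]).unwrap();

        // Entries are addressed by a global user index: index 20 falls in the second
        // mini tree (20 / 16 = 1) at local position 4 (20 % 16)
        let entry = aggregation_mst.get_entry(20);
        println!("username: {}", entry.username());

        // Generate and verify an inclusion proof that goes from the user leaf up to the aggregation root
        let proof = aggregation_mst.generate_proof(20).unwrap();
        assert!(aggregation_mst.verify_proof(&proof));
    }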