From fd8ced0d72c19f307e4fe07df9f71d7d612bb48c Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Mon, 5 Jun 2023 13:26:25 -0400 Subject: [PATCH] Fix LUTs in helm.rs --- Cargo.toml | 2 +- README.md | 16 ++++ src/bin/helm.rs | 56 ++++------- src/circuit.rs | 213 ++++++++++++++---------------------------- src/gates.rs | 26 ++---- src/verilog_parser.rs | 25 ++--- 6 files changed, 120 insertions(+), 218 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2cedf1e..9220684 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "helm" -description = "HELM: Homomorphic Evaluation with Lookup table Memoization" +description = "HELM: Homomorphic Evaluation with EDA-driven Logic Minimization" version = "0.1.0" edition = "2021" authors = ["Dimitris Mouris ", "Charles Gouert "] diff --git a/README.md b/README.md index b633656..8c47414 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,18 @@ git clone --recurse-submodules git@github.com:TrustworthyComputing/helm.git ``` ### Build & Run + +Compile and run the tests: ```shell cargo build --release cargo test --release +``` + +HELM has two modes: "gates"-mode and "LUTs"-mode. HELM automatically detects if +a LUTs or a gates circuit has been provided as input. Below are two examples: + +Example in "gates"-mode: +```shell cargo run --bin helm --release -- \ --input ./hdl-benchmarks/processed-netlists/s27.v cargo run --bin helm --release -- \ @@ -25,6 +34,13 @@ cargo run --bin helm --release -- \ --wires ./hdl-benchmarks/test-cases/2-bit-adder.inputs.csv ``` +Example in "LUTs"-mode: +```shell +cargo run --bin helm --release -- \ + --input ./hdl-benchmarks/processed-netlists/8-bit-adder-lut-3-1.v \ + --wires hdl-benchmarks/test-cases/8-bit-adder.inputs.csv +``` + ### Example of an ISCAS'85 circuit If a circuit is in the [netlists](./hdl-benchmarks/netlists/) directory but not in the [processed-netlists](./hdl-benchmarks/processed-netlists/), run the diff --git a/src/bin/helm.rs b/src/bin/helm.rs index 7f9de9d..e812e76 100644 --- a/src/bin/helm.rs +++ b/src/bin/helm.rs @@ -3,24 +3,11 @@ use debug_print::debug_println; use helm::{ascii, circuit, circuit::EvalCircuit, verilog_parser}; use std::{collections::HashMap, time::Instant}; use termion::color; -use tfhe::{ - boolean::prelude::*, - integer::{ - wopbs::WopbsKey as WopbsKeyInt, ClientKey as ClientKeyInt, ServerKey as ServerKeyInt, - }, - shortint::{ - parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1, - PARAM_MESSAGE_1_CARRY_1, - PARAM_MESSAGE_3_CARRY_0, - }, - wopbs::WopbsKey as WopbsKeyShortInt, - }, -}; +use tfhe::{boolean::prelude::*, shortint::parameters::PARAM_MESSAGE_3_CARRY_0}; fn parse_args() -> (String, usize, bool, HashMap) { let matches = Command::new("HELM") - .about("HELM: Homomorphic Evaluation with Lookup table Memoization") + .about("HELM: Homomorphic Evaluation with EDA-driven Logic Minimization") .arg( Arg::new("input") .long("input") @@ -75,12 +62,7 @@ fn parse_args() -> (String, usize, bool, HashMap) { } }; - ( - file_name.to_string(), - num_cycles, - verbose, - input_wire_map, - ) + (file_name.to_string(), num_cycles, verbose, input_wire_map) } fn main() { @@ -133,6 +115,12 @@ fn main() { // Encrypted Evaluation if !has_luts { + println!( + "{} -- Gates mode -- {}", + color::Fg(color::LightYellow), + color::Fg(color::Reset) + ); + // Gate mode let mut start = Instant::now(); let (client_key, server_key) = gen_keys(); @@ -168,28 +156,18 @@ fn main() { start.elapsed().as_secs_f64() ); } else { + println!( + "{} -- LUTs mode -- {}", + color::Fg(color::LightYellow), + color::Fg(color::Reset) + ); + // LUT mode let mut start = Instant::now(); - let (client_key_shortint, server_key_shortint) = - tfhe::shortint::gen_keys(PARAM_MESSAGE_1_CARRY_1); // single bit ctxt - let client_key = ClientKeyInt::from(client_key_shortint.clone()); - let server_key = ServerKeyInt::from_shortint(&client_key, server_key_shortint.clone()); - let wopbs_key_shortint = WopbsKeyShortInt::new_wopbs_key( - &client_key_shortint, - &server_key_shortint, - &&WOPBS_PARAM_MESSAGE_1_CARRY_1, - ); - let wopbs_key = WopbsKeyInt::from(wopbs_key_shortint.clone()); + let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_3_CARRY_0); // single bit ctxt + let mut circuit = circuit::LutCircuit::new(client_key, server_key, circuit_ptxt); println!("KeyGen done in {} seconds.", start.elapsed().as_secs_f64()); - let mut circuit = circuit::HighPrecisionLutCircuit::new( - wopbs_key_shortint, - wopbs_key, - client_key.clone(), - server_key, - circuit_ptxt, - ); - // Client encrypts their inputs start = Instant::now(); let mut enc_wire_map = diff --git a/src/circuit.rs b/src/circuit.rs index d875c54..3651654 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -10,25 +10,22 @@ use termion::color; use tfhe::{ boolean::prelude::*, integer::{ - wopbs::WopbsKey as WopbsKeyInt, ClientKey as ClientKeyInt, - ServerKey as ServerKeyInt, + wopbs::WopbsKey as WopbsKeyInt, ClientKey as ClientKeyInt, ServerKey as ServerKeyInt, }, shortint::{ ciphertext::{CiphertextBase, KeyswitchBootstrap}, wopbs::WopbsKey as WopbsKeyShortInt, - ClientKey as ClientKeyShortInt, - ServerKey as ServerKeyShortInt, + ClientKey as ClientKeyShortInt, ServerKey as ServerKeyShortInt, }, }; -#[cfg(test)] -use rand::Rng; #[cfg(test)] use debug_print::debug_println; #[cfg(test)] +use rand::Rng; +#[cfg(test)] use tfhe::shortint::parameters::{ - PARAM_MESSAGE_1_CARRY_1, - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_3_CARRY_0, }; @@ -71,8 +68,7 @@ pub struct LutCircuit<'a> { server_key: ServerKeyShortInt, } -// Note: this is not used as there is no easy way to get LUTs with more than -// six inputs. +// Note: this is not used as there is no easy way to get LUTs with more than six inputs. pub struct HighPrecisionLutCircuit<'a> { circuit: Circuit<'a>, wopbs_shortkey: WopbsKeyShortInt, @@ -98,7 +94,7 @@ impl<'a> Circuit<'a> { } } -// TODO: sequential + // TODO: sequential // Topologically sort the gates pub fn sort_circuit(&mut self) { assert!(!self.gates.is_empty()); @@ -155,8 +151,12 @@ impl<'a> Circuit<'a> { for gate in &mut self.ordered_gates { if gate.get_gate_type() == GateType::Dff { match self.level_map.entry(std::usize::MAX) { - Entry::Vacant(e) => { e.insert(vec![gate.clone()]); } - Entry::Occupied(mut e) => { e.get_mut().push(gate.clone()); } + Entry::Vacant(e) => { + e.insert(vec![gate.clone()]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(gate.clone()); + } } gate.set_level(std::usize::MAX); continue; @@ -173,8 +173,12 @@ impl<'a> Circuit<'a> { gate.set_level(depth); match self.level_map.entry(depth) { - Entry::Vacant(e) => { e.insert(vec![gate.clone()]); } - Entry::Occupied(mut e) => { e.get_mut().push(gate.clone()); } + Entry::Vacant(e) => { + e.insert(vec![gate.clone()]); + } + Entry::Occupied(mut e) => { + e.get_mut().push(gate.clone()); + } } wire_levels.insert(gate.get_output_wire(), depth); @@ -191,7 +195,7 @@ impl<'a> Circuit<'a> { } } - // Remove all the gates after the compute levels is done. Use + // Remove all the gates after the compute levels is done. Use // self.level_map from now on. self.ordered_gates.clear(); } @@ -265,28 +269,23 @@ impl<'a> Circuit<'a> { // Get the corresponding index in the wires array let output_index = key_to_index[&gate.get_output_wire()]; - + // Update the value of the corresponding key - *eval_values[output_index].write() + *eval_values[output_index] + .write() .expect("Failed to acquire write lock") = output_value; }); } key_to_index .iter() - .map(|(&key, &index)| - (key.to_string(), *eval_values[index].read().unwrap()) - ) + .map(|(&key, &index)| (key.to_string(), *eval_values[index].read().unwrap())) .collect::>() } } impl<'a> GateCircuit<'a> { - pub fn new( - client_key: ClientKey, - server_key: ServerKey, - circuit: Circuit - ) -> GateCircuit { + pub fn new(client_key: ClientKey, server_key: ServerKey, circuit: Circuit) -> GateCircuit { GateCircuit { client_key, server_key, @@ -299,7 +298,7 @@ impl<'a> LutCircuit<'a> { pub fn new( client_key: ClientKeyShortInt, server_key: ServerKeyShortInt, - circuit: Circuit + circuit: Circuit, ) -> LutCircuit { LutCircuit { client_key, @@ -337,16 +336,12 @@ impl<'a> EvalCircuit for GateCircuit<'a> { ) -> HashMap { let mut enc_wire_map = HashMap::::new(); for (wire, &value) in wire_map_im { - enc_wire_map.insert( - wire.to_string(), self.client_key.encrypt(value) - ); + enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(value)); } for input_wire in self.circuit.input_wires { // if no inputs are provided, initialize it to false if input_wire_map.is_empty() { - enc_wire_map.insert( - input_wire.to_string(), self.client_key.encrypt(false) - ); + enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt(false)); } else if !input_wire_map.contains_key(input_wire) { panic!("\n Input wire \"{}\" not in input wires!", input_wire); } else { @@ -357,9 +352,7 @@ impl<'a> EvalCircuit for GateCircuit<'a> { } } for wire in self.circuit.dff_outputs { - enc_wire_map.insert( - wire.to_string(), self.client_key.encrypt(false) - ); + enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(false)); } enc_wire_map @@ -378,9 +371,7 @@ impl<'a> EvalCircuit for GateCircuit<'a> { let (key_to_index, eval_values): (HashMap<_, _>, Vec<_>) = enc_wire_map .iter() .enumerate() - .map(|(i, (key, value))| - ((key, i), Arc::new(RwLock::new(value.clone()))) - ) + .map(|(i, (key, value))| ((key, i), Arc::new(RwLock::new(value.clone())))) .unzip(); // For each level @@ -400,16 +391,14 @@ impl<'a> EvalCircuit for GateCircuit<'a> { // Get the corresponding index in the wires array let index = match key_to_index.get(input) { Some(&index) => index, - None => panic!( - "Input wire {} not in key_to_index map", input), + None => panic!("Input wire {} not in key_to_index map", input), }; // Read the value of the corresponding key eval_values[index].read().unwrap().clone() }) .collect(); - let output_value = gate.evaluate_encrypted( - &self.server_key, &input_values, cycle); + let output_value = gate.evaluate_encrypted(&self.server_key, &input_values, cycle); // Get the corresponding index in the wires array let output_index = key_to_index[&gate.get_output_wire()]; @@ -422,18 +411,11 @@ impl<'a> EvalCircuit for GateCircuit<'a> { key_to_index .iter() - .map(|(&key, &index)| - (key.to_string(), eval_values[index].read().unwrap().clone()) - ) + .map(|(&key, &index)| (key.to_string(), eval_values[index].read().unwrap().clone())) .collect() } - fn decrypt_outputs( - &self, - enc_wire_map: &HashMap, - verbose: bool - ) { + fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { if i > 10 && !verbose { println!( @@ -463,23 +445,14 @@ impl<'a> EvalCircuit> for LutCircuit<'a> { ) -> HashMap { let mut enc_wire_map = HashMap::::new(); for (wire, &value) in wire_map_im { - enc_wire_map.insert( - wire.to_string(), - self.client_key.encrypt(value as u64), - ); + enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(value as u64)); } for input_wire in self.circuit.input_wires { // if no inputs are provided, initialize it to false if input_wire_map.is_empty() { - enc_wire_map.insert( - input_wire.to_string(), - self.client_key.encrypt(0) - ); + enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt(0)); } else if !input_wire_map.contains_key(input_wire) { - panic!( - "\n Input wire \"{}\" not found in input wires!", - input_wire - ); + panic!("\n Input wire \"{}\" not found in input wires!", input_wire); } else { enc_wire_map.insert( input_wire.to_string(), @@ -488,10 +461,7 @@ impl<'a> EvalCircuit> for LutCircuit<'a> { } } for wire in self.circuit.dff_outputs { - enc_wire_map.insert( - wire.to_string(), - self.client_key.encrypt(0) - ); + enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(0)); } enc_wire_map @@ -510,9 +480,7 @@ impl<'a> EvalCircuit> for LutCircuit<'a> { let (key_to_index, eval_values): (HashMap<_, _>, Vec<_>) = enc_wire_map .iter() .enumerate() - .map(|(i, (key, value))| - ((key, i), Arc::new(RwLock::new(value.clone()))) - ) + .map(|(i, (key, value))| ((key, i), Arc::new(RwLock::new(value.clone())))) .unzip(); // For each level @@ -532,47 +500,34 @@ impl<'a> EvalCircuit> for LutCircuit<'a> { // Get the corresponding index in the wires array let index = match key_to_index.get(input) { Some(&index) => index, - None => panic!( - "Input wire {} not in key_to_index map", - input - ), + None => panic!("Input wire {} not in key_to_index map", input), }; // Read the value of the corresponding key eval_values[index].read().unwrap().clone() }) .collect(); - let output_value = gate.evaluate_encrypted_lut( - &self.server_key, - &input_values, - cycle, - ); + let output_value = + gate.evaluate_encrypted_lut(&self.server_key, &input_values, cycle); // Get the corresponding index in the wires array let output_index = key_to_index[&gate.get_output_wire()]; - + // Update the value of the corresponding key - *eval_values[output_index].write() + *eval_values[output_index] + .write() .expect("Failed to acquire write lock") = output_value; - }); println!(" Evaluated gates in level [{}/{}]", level, total_levels); } key_to_index .iter() - .map(|(&key, &index)| - (key.to_string(), eval_values[index].read().unwrap().clone()) - ) + .map(|(&key, &index)| (key.to_string(), eval_values[index].read().unwrap().clone())) .collect() } - fn decrypt_outputs( - &self, - enc_wire_map: &HashMap, - verbose: bool - ) { + fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { if i > 10 && !verbose { println!( @@ -610,15 +565,9 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir for input_wire in self.circuit.input_wires { // if no inputs are provided, initialize it to false if input_wire_map.is_empty() { - enc_wire_map.insert( - input_wire.to_string(), - self.client_key.encrypt_one_block(0) - ); + enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt_one_block(0)); } else if !input_wire_map.contains_key(input_wire) { - panic!( - "\n Input wire \"{}\" not found in input wires!", - input_wire - ); + panic!("\n Input wire \"{}\" not found in input wires!", input_wire); } else { enc_wire_map.insert( input_wire.to_string(), @@ -628,10 +577,7 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir } } for wire in self.circuit.dff_outputs { - enc_wire_map.insert( - wire.to_string(), - self.client_key.encrypt_one_block(0) - ); + enc_wire_map.insert(wire.to_string(), self.client_key.encrypt_one_block(0)); } enc_wire_map @@ -650,9 +596,7 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir let (key_to_index, eval_values): (HashMap<_, _>, Vec<_>) = enc_wire_map .iter() .enumerate() - .map(|(i, (key, value))| - ((key, i), Arc::new(RwLock::new(value.clone()))) - ) + .map(|(i, (key, value))| ((key, i), Arc::new(RwLock::new(value.clone())))) .unzip(); // For each level @@ -672,10 +616,7 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir // Get the corresponding index in the wires array let index = match key_to_index.get(input) { Some(&index) => index, - None => panic!( - "Input wire {} not in key_to_index map", - input - ), + None => panic!("Input wire {} not in key_to_index map", input), }; // Read the value of the corresponding key @@ -692,9 +633,10 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir // Get the corresponding index in the wires array let output_index = key_to_index[&gate.get_output_wire()]; - + // Update the value of the corresponding key - *eval_values[output_index].write() + *eval_values[output_index] + .write() .expect("Failed to acquire write lock") = output_value; }); println!(" Evaluated gates in level [{}/{}]", level, total_levels); @@ -702,18 +644,11 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir key_to_index .iter() - .map(|(&key, &index)| - (key.to_string(), eval_values[index].read().unwrap().clone()) - ) + .map(|(&key, &index)| (key.to_string(), eval_values[index].read().unwrap().clone())) .collect() } - fn decrypt_outputs( - &self, - enc_wire_map: &HashMap, - verbose: bool - ) { + fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { if i > 10 && !verbose { println!( @@ -734,7 +669,6 @@ impl<'a> EvalCircuit> for HighPrecisionLutCir } } - #[test] fn test_gate_evaluation() { let (client_key, server_key) = gen_keys(); @@ -764,8 +698,7 @@ fn test_gate_evaluation() { } let output_value_ptxt = gate.evaluate(&inputs_ptxt, 1); - let output_value_enc = - gate.evaluate_encrypted(&server_key, &inputs_ctxt, 1); + let output_value_enc = gate.evaluate_encrypted(&server_key, &inputs_ctxt, 1); if gate.get_gate_type() == GateType::Lut { continue; } @@ -779,8 +712,7 @@ fn test_gate_evaluation() { #[test] fn test_evaluate_circuit_parallel() { let (gates_set, mut wire_map, input_wires, _, _, _, _) = - crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v"); + crate::verilog_parser::read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v"); let empty = vec![]; let mut circuit = Circuit::new(gates_set, &input_wires, &empty, &empty); @@ -803,12 +735,10 @@ fn test_evaluate_circuit_parallel() { assert_eq!(wire_map["i1"], false); } - #[test] fn test_evaluate_encrypted_circuit_parallel() { let (gates_set, wire_map_im, input_wires, _, _, _, _) = - crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v"); + crate::verilog_parser::read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v"); let mut ptxt_wire_map = wire_map_im.clone(); let empty = vec![]; @@ -853,9 +783,10 @@ fn test_evaluate_encrypted_circuit_parallel() { fn test_evaluate_encrypted_lut_circuit_parallel() { let (gates_set, wire_map_im, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/8-bit-adder-lut-3-1.v"); - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/8-bit-adder.inputs.csv"); + "hdl-benchmarks/processed-netlists/8-bit-adder-lut-3-1.v", + ); + let input_wire_map = + crate::verilog_parser::read_input_wires("hdl-benchmarks/test-cases/8-bit-adder.inputs.csv"); let empty = vec![]; let mut circuit_ptxt = Circuit::new(gates_set, &input_wires, &empty, &empty); @@ -865,8 +796,7 @@ fn test_evaluate_encrypted_lut_circuit_parallel() { let mut ptxt_wire_map = circuit_ptxt.initialize_wire_map(&wire_map_im, &input_wire_map); // Encrypted - let (client_key, server_key) = - tfhe::shortint::gen_keys(PARAM_MESSAGE_3_CARRY_0); // single bit ctxt + let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_3_CARRY_0); // single bit ctxt // Plaintext for input_wire in &input_wires { @@ -874,11 +804,7 @@ fn test_evaluate_encrypted_lut_circuit_parallel() { } ptxt_wire_map = circuit_ptxt.evaluate(&ptxt_wire_map, 1); - let mut circuit = LutCircuit::new( - client_key.clone(), - server_key, - circuit_ptxt - ); + let mut circuit = LutCircuit::new(client_key.clone(), server_key, circuit_ptxt); let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_map_im, &input_wire_map); enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 1); @@ -901,9 +827,10 @@ fn test_evaluate_encrypted_lut_circuit_parallel() { fn test_evaluate_encrypted_high_precision_lut_circuit_parallel() { let (gates_set, wire_map_im, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/8-bit-adder-lut-high-precision.v"); - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/8-bit-adder.inputs.csv"); + "hdl-benchmarks/processed-netlists/8-bit-adder-lut-high-precision.v", + ); + let input_wire_map = + crate::verilog_parser::read_input_wires("hdl-benchmarks/test-cases/8-bit-adder.inputs.csv"); let empty = vec![]; let mut circuit_ptxt = Circuit::new(gates_set, &input_wires, &empty, &empty); diff --git a/src/gates.rs b/src/gates.rs index 42ab821..c4b8cb8 100644 --- a/src/gates.rs +++ b/src/gates.rs @@ -130,7 +130,7 @@ impl Gate { } pub fn evaluate(&mut self, input_values: &Vec, cycle: usize) -> bool { - if let Some(output) = self.output.clone() { + if let Some(output) = self.output { if self.cycle == cycle { return output; } @@ -186,9 +186,7 @@ impl Gate { GateType::And => server_key.and(&input_values[0], &input_values[1]), GateType::Dff => input_values[0].clone(), GateType::Lut => panic!("Can't mix LUTs with Boolean gates!"), - GateType::Mux => server_key.mux( - &input_values[2], &input_values[0], &input_values[1] - ), + GateType::Mux => server_key.mux(&input_values[2], &input_values[0], &input_values[1]), GateType::Nand => server_key.nand(&input_values[0], &input_values[1]), GateType::Nor => server_key.nor(&input_values[0], &input_values[1]), GateType::Not => server_key.not(&input_values[0]), @@ -213,11 +211,7 @@ impl Gate { } } - lut( - server_key, - &self.lut_const.as_ref().unwrap(), - input_values, - ) + lut(server_key, self.lut_const.as_ref().unwrap(), input_values) } pub fn evaluate_encrypted_high_precision_lut( @@ -238,11 +232,10 @@ impl Gate { wopbs_shortkey, wopbs_intkey, server_intkey, - &self.lut_const.as_ref().unwrap(), + self.lut_const.as_ref().unwrap(), input_values, ) } - } // Shift the constant by ctxt amount @@ -259,7 +252,7 @@ pub fn lut( let ct_sum = ctxts .iter() .enumerate() - .map(|(i, ct)| sks.scalar_mul(ct, 1 << ctxts.len() - 1 - i)) + .map(|(i, ct)| sks.scalar_mul(ct, 1 << (ctxts.len() - 1 - i))) .fold(sks.create_trivial(0), |acc, ct| sks.add(&acc, &ct)); // Generate LUT entries from lut_const @@ -281,15 +274,14 @@ pub fn high_precision_lut( for block in ctxts { combined_vec.insert(0, block.clone()); } - let radix_ct = BaseRadixCiphertext::>:: - from_blocks(combined_vec); + let radix_ct = + BaseRadixCiphertext::>::from_blocks(combined_vec); // KS to WoPBS - let radix_ct = wk.keyswitch_to_wopbs_params(&sks, &radix_ct); + let radix_ct = wk.keyswitch_to_wopbs_params(sks, &radix_ct); // Generate LUT entries from lut_const - let lut = generate_high_precision_lut_radix_helm( - &wk_si, &radix_ct, eval_luts, lut_const); + let lut = generate_high_precision_lut_radix_helm(wk_si, &radix_ct, eval_luts, lut_const); // Eval PBS let radix_ct = wk.wopbs(&radix_ct, &lut); diff --git a/src/verilog_parser.rs b/src/verilog_parser.rs index 6c7437e..9d419ce 100644 --- a/src/verilog_parser.rs +++ b/src/verilog_parser.rs @@ -66,7 +66,7 @@ fn parse_gate(tokens: &[&str]) -> Gate { let lut_const_str = input_wires.remove(0); let lut_const_int = if lut_const_str.starts_with("0x") { Some( - match usize::from_str_radix(&lut_const_str.trim_start_matches("0x"), 16) { + match usize::from_str_radix(lut_const_str.trim_start_matches("0x"), 16) { Ok(n) => n, Err(_) => panic!("Failed to parse hex"), }, @@ -83,14 +83,7 @@ fn parse_gate(tokens: &[&str]) -> Gate { None }; - Gate::new( - gate_name, - gate_type, - input_wires, - lut_const, - output_wire, - 0, - ) + Gate::new(gate_name, gate_type, input_wires, lut_const, output_wire, 0) } fn parse_range(range_str: &str) -> Option<(usize, usize)> { @@ -228,9 +221,8 @@ pub fn read_input_wires(file_name: &str) -> HashMap { #[test] fn test_parser() { - let (gates, wire_map, inputs, _, _, _, _) = read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v" - ); + let (gates, wire_map, inputs, _, _, _, _) = + read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v"); assert_eq!(gates.len(), 10); assert_eq!(wire_map.len(), 10); @@ -239,13 +231,10 @@ fn test_parser() { #[test] fn test_input_wires_parser() { - let (_, _, inputs, _, _, _, _) = read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v" - ); + let (_, _, inputs, _, _, _, _) = + read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v"); - let input_wires_map = read_input_wires( - "hdl-benchmarks/test-cases/2-bit-adder.inputs.csv" - ); + let input_wires_map = read_input_wires("hdl-benchmarks/test-cases/2-bit-adder.inputs.csv"); assert_eq!(input_wires_map.len(), inputs.len()); for input_wire in inputs {