diff --git a/Cargo.toml b/Cargo.toml index 789a633..7b016ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "helm" -description = "HELM: Homomorphic Evaluation with EDA-driven Logic Minimization" +description = "HELM: Navigating Homomorphic Evaluation through Gates and Lookups" version = "0.1.0" edition = "2021" authors = ["Dimitris Mouris ", "Charles Gouert "] diff --git a/src/bin/helm.rs b/src/bin/helm.rs index 744c759..350552e 100644 --- a/src/bin/helm.rs +++ b/src/bin/helm.rs @@ -2,8 +2,9 @@ use debug_print::debug_println; use helm::{ascii, circuit, circuit::EvalCircuit, verilog_parser}; use std::time::Instant; use termion::color; -use tfhe::{boolean::prelude::*, shortint::parameters::PARAM_MESSAGE_2_CARRY_0}; -use tfhe::{generate_keys, ConfigBuilder}; +use tfhe::{ + boolean::gen_keys, generate_keys, shortint::parameters::PARAM_MESSAGE_2_CARRY_0, ConfigBuilder, +}; fn main() { ascii::print_art(); diff --git a/src/circuit.rs b/src/circuit.rs index 4786cf2..c3e4288 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -22,18 +22,6 @@ use tfhe::{ unset_server_key, FheUint128, FheUint16, FheUint32, FheUint64, FheUint8, }; -#[cfg(test)] -use debug_print::debug_println; -#[cfg(test)] -use rand::Rng; -#[cfg(test)] -use tfhe::shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_1_CARRY_1, - PARAM_MESSAGE_3_CARRY_0, -}; -#[cfg(test)] -use tfhe::{generate_keys, ConfigBuilder}; - use crate::{FheType, PtxtType}; pub trait EvalCircuit { @@ -337,6 +325,10 @@ impl<'a> Circuit<'a> { } } + pub fn get_ordered_gates(&self) -> &Vec { + &self.ordered_gates + } + pub fn evaluate(&mut self, wire_map: &HashMap) -> HashMap { // Make sure the sort circuit function has run. assert!(self.gates.is_empty()); @@ -779,7 +771,7 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { set_server_key(self.server_key.clone()); rayon::broadcast(|_| set_server_key(self.server_key.clone())); - + // For each level let total_levels = self.circuit.level_map.len(); for (level, gates) in self @@ -840,9 +832,19 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { } else if gate.get_gate_type() == GateType::Div { gate.evaluate_encrypted_div_block_plain(&ct_op, ptxt_operand, cycle) } else if gate.get_gate_type() == GateType::Shl { - gate.evaluate_encrypted_shift_block_plain(&ct_op, ptxt_operand, cycle, true) + gate.evaluate_encrypted_shift_block_plain( + &ct_op, + ptxt_operand, + cycle, + true, + ) } else if gate.get_gate_type() == GateType::Shr { - gate.evaluate_encrypted_shift_block_plain(&ct_op, ptxt_operand, cycle, false) + gate.evaluate_encrypted_shift_block_plain( + &ct_op, + ptxt_operand, + cycle, + false, + ) } else { unreachable!(); } @@ -882,22 +884,20 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { ) } else if gate.get_gate_type() == GateType::Shl { gate.evaluate_encrypted_shift_block( - &input_values[0], - &input_values[1], - cycle, + &input_values[0], + &input_values[1], + cycle, true, ) } else if gate.get_gate_type() == GateType::Shr { gate.evaluate_encrypted_shift_block( - &input_values[0], - &input_values[1], - cycle, - false) + &input_values[0], + &input_values[1], + cycle, + false, + ) } else if gate.get_gate_type() == GateType::Copy { - gate.evaluate_encrypted_copy_block( - &input_values[0], - cycle, - ) + gate.evaluate_encrypted_copy_block(&input_values[0], cycle) } else { gate.evaluate_encrypted_mul_block( &input_values[0], @@ -965,12 +965,7 @@ impl<'a> EvalCircuit for HighPrecisionLutCircuit<'a> { ) -> HashMap { let mut enc_wire_map = wire_set .iter() - .map(|wire| { - ( - wire.to_string(), - self.client_key.encrypt_one_block(0u64), - ) - }) + .map(|wire| (wire.to_string(), self.client_key.encrypt_one_block(0u64))) .collect::>(); for input_wire in self.circuit.input_wires { // if no inputs are provided, initialize it to false @@ -1093,460 +1088,3 @@ impl<'a> EvalCircuit for HighPrecisionLutCircuit<'a> { decrypted_outputs } } - -#[test] -fn test_gate_evaluation() { - let (client_key, server_key) = gen_keys(); - - let ptxts = vec![PtxtType::Bool(true), PtxtType::Bool(false)]; - let ctxts = vec![client_key.encrypt(true), client_key.encrypt(false)]; - let gates = vec![ - Gate::new( - String::from(""), - GateType::And, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Or, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Nor, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Xor, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Nand, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Not, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Xnor, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Mux, - vec![], - None, - "".to_string(), - 0, - ), - ]; - let mut rng = rand::thread_rng(); - let mut cycle = 1; - for mut gate in gates { - for i in 0..2 { - for j in 0..2 { - let mut inputs_ptxt = vec![ptxts[i], ptxts[j]]; - let mut inputs_ctxt = vec![ctxts[i].clone(), ctxts[j].clone()]; - if gate.get_gate_type() == GateType::Mux { - let select: bool = rng.gen(); - inputs_ptxt.push(PtxtType::Bool(select)); - inputs_ctxt.push(client_key.encrypt(select)); - } - let output_value_ptxt = gate.evaluate(&inputs_ptxt); - - let output_value_ctxt = gate.evaluate_encrypted(&server_key, &inputs_ctxt, cycle); - if gate.get_gate_type() == GateType::Lut { - continue; - } - - assert_eq!( - output_value_ptxt, - PtxtType::Bool(client_key.decrypt(&output_value_ctxt)) - ); - - cycle += 1; - } - } - } -} - -#[test] -fn test_evaluate_circuit() { - let (gates_set, wire_set, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v", - false, - ); - - let empty = vec![]; - let mut circuit = Circuit::new(gates_set, &input_wires, &empty, &empty); - circuit.sort_circuit(); - assert_eq!(circuit.ordered_gates.len(), 10); - circuit.compute_levels(); - - let mut wire_map = HashMap::new(); - for wire in &wire_set { - wire_map.insert(wire.to_string(), PtxtType::Bool(true)); - } - for input_wire in &input_wires { - wire_map.insert(input_wire.to_string(), PtxtType::Bool(true)); - } - wire_map = circuit.evaluate(&wire_map); - - assert_eq!(wire_map.len(), 15); - assert_eq!(input_wires.len(), 5); - - assert_eq!(wire_map["sum[0]"], PtxtType::Bool(true)); - assert_eq!(wire_map["sum[1]"], PtxtType::Bool(true)); - assert_eq!(wire_map["cout"], PtxtType::Bool(true)); - assert_eq!(wire_map["i0"], PtxtType::Bool(false)); - assert_eq!(wire_map["i1"], PtxtType::Bool(false)); -} - -#[test] -fn test_evaluate_encrypted_circuit() { - let datatype = "bool"; - let (gates_set, wire_set, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/2-bit-adder.v", - false, - ); - - let empty = vec![]; - let mut circuit = Circuit::new(gates_set, &input_wires, &empty, &empty); - circuit.sort_circuit(); - circuit.compute_levels(); - - // Encrypted - let (client_key, server_key) = gen_keys(); - - // Plaintext - let mut ptxt_wire_map = HashMap::new(); - for wire in &wire_set { - ptxt_wire_map.insert(wire.to_string(), PtxtType::Bool(true)); - } - for input_wire in &input_wires { - ptxt_wire_map.insert(input_wire.to_string(), PtxtType::Bool(true)); - } - ptxt_wire_map = circuit.evaluate(&ptxt_wire_map); - - let mut enc_wire_map = HashMap::new(); - for wire in wire_set { - enc_wire_map.insert(wire, client_key.encrypt(false)); - } - for input_wire in &input_wires { - enc_wire_map.insert(input_wire.to_string(), client_key.encrypt(true)); - } - let mut circuit = GateCircuit::new(client_key.clone(), server_key, circuit); - - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 1, datatype); - - let mut dec_wire_map = HashMap::new(); - for wire_name in enc_wire_map.keys().sorted() { - dec_wire_map.insert( - wire_name.to_string(), - client_key.decrypt(&enc_wire_map[wire_name]), - ); - } - - // Check that encrypted and plaintext evaluations are equal - for key in ptxt_wire_map.keys() { - assert_eq!(ptxt_wire_map[key], PtxtType::Bool(dec_wire_map[key])); - } -} - -#[test] -fn test_evaluate_encrypted_lut_circuit() { - let datatype = "bool"; - let (gates_set, wire_set, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/8-bit-adder-lut-3-1.v", - false, - ); - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/8-bit-adder.inputs.csv", - datatype, - ); - - let empty = vec![]; - let mut circuit_ptxt = Circuit::new(gates_set, &input_wires, &empty, &empty); - - circuit_ptxt.sort_circuit(); - circuit_ptxt.compute_levels(); - - let mut ptxt_wire_map = circuit_ptxt.initialize_wire_map(&wire_set, &input_wire_map, datatype); - - // Encrypted single bit ctxt - let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_3_CARRY_0); - - // Plaintext - for input_wire in &input_wires { - ptxt_wire_map.insert(input_wire.to_string(), input_wire_map[input_wire]); - } - ptxt_wire_map = circuit_ptxt.evaluate(&ptxt_wire_map); - - let mut circuit = LutCircuit::new(client_key.clone(), server_key, circuit_ptxt); - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 1, datatype); - - let mut dec_wire_map = HashMap::new(); - for wire_name in enc_wire_map.keys().sorted() { - dec_wire_map.insert( - wire_name.to_string(), - client_key.decrypt(&enc_wire_map[wire_name]) == 1, - ); - } - - // Check that encrypted and plaintext evaluations are equal - for key in ptxt_wire_map.keys() { - assert_eq!(ptxt_wire_map[key], PtxtType::Bool(dec_wire_map[key])); - } - debug_println!("wire map: {:?}", dec_wire_map); -} - -#[test] -fn test_evaluate_encrypted_high_precision_lut_circuit() { - let datatype = "bool"; - let (gates_set, wire_set, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/8-bit-adder-lut-high-precision.v", - false, - ); - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/8-bit-adder.inputs.csv", - datatype, - ); - - let empty = vec![]; - let mut circuit_ptxt = Circuit::new(gates_set, &input_wires, &empty, &empty); - circuit_ptxt.sort_circuit(); - circuit_ptxt.compute_levels(); - let mut ptxt_wire_map = circuit_ptxt.initialize_wire_map(&wire_set, &input_wire_map, datatype); - - // Encrypted - let (client_key_shortint, server_key_shortint) = - tfhe::shortint::gen_keys(PARAM_MESSAGE_1_CARRY_1); // single bit ctxt - let client_key = ClientKeyInt::from(client_key_shortint.clone()); - let server_key = ServerKeyInt::from_shortint(&client_key, server_key_shortint.clone()); - - let wopbs_key_shortint = WopbsKeyShortInt::new_wopbs_key( - &client_key_shortint, - &server_key_shortint, - &WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, - ); - let wopbs_key = WopbsKeyInt::from(wopbs_key_shortint.clone()); - - // Plaintext - for input_wire in &input_wires { - ptxt_wire_map.insert(input_wire.to_string(), input_wire_map[input_wire]); - } - ptxt_wire_map = circuit_ptxt.evaluate(&ptxt_wire_map); - - let mut circuit = HighPrecisionLutCircuit::new( - wopbs_key_shortint, - wopbs_key, - client_key.clone(), - server_key, - circuit_ptxt, - ); - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 1, datatype); - - let mut dec_wire_map = HashMap::new(); - for wire_name in enc_wire_map.keys().sorted() { - dec_wire_map.insert( - wire_name.to_string(), - client_key.decrypt_one_block(&enc_wire_map[wire_name]), - ); - } - - // Check that encrypted and plaintext evaluations are equal - for key in ptxt_wire_map.keys() { - assert_eq!(ptxt_wire_map[key], PtxtType::Bool(dec_wire_map[key] != 0)); - } - debug_println!("wire map: {:?}", dec_wire_map); -} - -#[test] -fn test_evaluate_encrypted_arithmetic_circuit() { - let datatype = "u16"; - let (gates_set, wire_set, input_wires, _, _, _, _) = crate::verilog_parser::read_verilog_file( - "hdl-benchmarks/processed-netlists/chi_squared_arith.v", - true, - ); - let empty = vec![]; - let mut circuit_ptxt = Circuit::new(gates_set, &input_wires, &empty, &empty); - circuit_ptxt.sort_circuit(); - circuit_ptxt.compute_levels(); - - let config = ConfigBuilder::all_disabled() - .enable_custom_integers( - tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - None, - ) - .build(); - let (client_key, server_key) = generate_keys(config); // integer ctxt - let mut circuit = ArithCircuit::new(client_key.clone(), server_key, circuit_ptxt); - - // Input set 1 - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv", - datatype, - ); - let output_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_1.outputs.csv", - datatype, - ); - - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 1, datatype); - - // Check that the evaluation was correct - for (wire_name, value) in output_wire_map { - match (enc_wire_map[&wire_name].decrypt(&client_key), value) { - (PtxtType::U8(value), PtxtType::U8(expected_val)) => { - assert_eq!(value, expected_val) - } - (PtxtType::U16(value), PtxtType::U16(expected_val)) => { - assert_eq!(value, expected_val) - } - (PtxtType::U32(value), PtxtType::U32(expected_val)) => { - assert_eq!(value, expected_val) - } - (PtxtType::U64(value), PtxtType::U64(expected_val)) => { - assert_eq!(value, expected_val) - } - (PtxtType::U128(value), PtxtType::U128(expected_val)) => { - assert_eq!(value, expected_val) - } - _ => panic!("Decrypted shouldn't be None"), - }; - } - - // Input set 2 - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_2.inputs.csv", - datatype, - ); - let output_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_2.outputs.csv", - datatype, - ); - - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 2, datatype); - - // Check that the evaluation was correct - for (wire_name, value) in output_wire_map { - match (enc_wire_map[&wire_name].decrypt(&client_key), value) { - (PtxtType::U8(val), PtxtType::U8(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U16(val), PtxtType::U16(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U32(val), PtxtType::U32(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U64(val), PtxtType::U64(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U128(val), PtxtType::U128(expected_val)) => { - assert_eq!(val, expected_val) - } - _ => panic!("Decrypted shouldn't be None"), - }; - } - - // Input set 3 - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_3.inputs.csv", - datatype, - ); - let output_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_3.outputs.csv", - datatype, - ); - - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 3, datatype); - - // Check that the evaluation was correct - for (wire_name, value) in output_wire_map { - match (enc_wire_map[&wire_name].decrypt(&client_key), value) { - (PtxtType::U8(val), PtxtType::U8(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U16(val), PtxtType::U16(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U32(val), PtxtType::U32(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U64(val), PtxtType::U64(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U128(val), PtxtType::U128(expected_val)) => { - assert_eq!(val, expected_val) - } - _ => panic!("Decrypted shouldn't be None"), - }; - } - - // Input set 4 - let input_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_4.inputs.csv", - datatype, - ); - let output_wire_map = crate::verilog_parser::read_input_wires( - "hdl-benchmarks/test-cases/chi_squared_arith_4.outputs.csv", - datatype, - ); - - let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); - enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 4, datatype); - - // Check that the evaluation was correct - for (wire_name, value) in output_wire_map { - match (enc_wire_map[&wire_name].decrypt(&client_key), value) { - (PtxtType::U8(val), PtxtType::U8(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U16(val), PtxtType::U16(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U32(val), PtxtType::U32(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U64(val), PtxtType::U64(expected_val)) => { - assert_eq!(val, expected_val) - } - (PtxtType::U128(val), PtxtType::U128(expected_val)) => { - assert_eq!(val, expected_val) - } - _ => panic!("Decrypted shouldn't be None"), - }; - } -} diff --git a/src/gates.rs b/src/gates.rs index dcd5fce..c85042c 100644 --- a/src/gates.rs +++ b/src/gates.rs @@ -293,11 +293,7 @@ impl Gate { ret } - pub fn evaluate_encrypted_copy_block( - &mut self, - ct1: &FheType, - cycle: usize, - ) -> FheType { + pub fn evaluate_encrypted_copy_block(&mut self, ct1: &FheType, cycle: usize) -> FheType { if self.cycle == cycle { match self.encrypted_multibit_output { FheType::None => (), @@ -484,7 +480,7 @@ impl Gate { ct1: &FheType, ct2: &FheType, cycle: usize, - dir : bool, // true for left shifts + dir: bool, // true for left shifts ) -> FheType { if self.cycle == cycle { match self.encrypted_multibit_output { @@ -551,7 +547,7 @@ impl Gate { } self.encrypted_multibit_output = match (ct1, pt1) { - (FheType::U8(ct1_value), PtxtType::U8(pt1_value)) => { + (FheType::U8(ct1_value), PtxtType::U8(pt1_value)) => { if dir { FheType::U8(ct1_value << pt1_value) } else { @@ -839,206 +835,3 @@ where } lut } - -#[test] -fn test_caching_of_gate_evaluation() { - use std::time::Instant; - use tfhe::prelude::*; - use tfhe::set_server_key; - use tfhe::FheUint16; - use tfhe::{generate_keys, ConfigBuilder}; - - let config = ConfigBuilder::all_disabled() - .enable_custom_integers( - tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - None, - ) - .build(); - let (client_key, server_key) = generate_keys(config); // integer ctxt - set_server_key(server_key); - - let ptxt = vec![10, 20, 30, 40]; - let inputs_ctxt = vec![ - FheType::U16(FheUint16::try_encrypt(ptxt[0], &client_key).unwrap()), - FheType::U16(FheUint16::try_encrypt(ptxt[1], &client_key).unwrap()), - FheType::U16(FheUint16::try_encrypt(ptxt[2], &client_key).unwrap()), - FheType::U16(FheUint16::try_encrypt(ptxt[3], &client_key).unwrap()), - ]; - - let mut gates = vec![ - Gate::new( - String::from(""), - GateType::Add, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Sub, - vec![], - None, - "".to_string(), - 0, - ), - Gate::new( - String::from(""), - GateType::Mult, - vec![], - None, - "".to_string(), - 0, - ), - ]; - - for gate in gates.iter_mut() { - let mut cycle = 1; - - let mut start = Instant::now(); - let (result, ptxt_result) = match gate.get_gate_type() { - GateType::Add => ( - gate.evaluate_encrypted_add_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), - PtxtType::U16(ptxt[0] + ptxt[1]), - ), - GateType::Sub => ( - gate.evaluate_encrypted_sub_block(&inputs_ctxt[1], &inputs_ctxt[0], cycle), - PtxtType::U16(ptxt[1] - ptxt[0]), - ), - GateType::Mult => ( - gate.evaluate_encrypted_mul_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), - PtxtType::U16(ptxt[0] * ptxt[1]), - ), - _ => unreachable!(), - }; - let mut elapsed = start.elapsed().as_secs_f64(); - let mut decrypted = result.decrypt(&client_key); - match gate.get_gate_type() { - GateType::Add => { - println!( - "Cycle {}) {}+{}={} in {} seconds", - cycle, ptxt[0], ptxt[1], decrypted, elapsed - ); - } - GateType::Sub => { - println!( - "Cycle {}) {}-{}={} in {} seconds", - cycle, ptxt[1], ptxt[0], decrypted, elapsed - ); - } - GateType::Mult => { - println!( - "Cycle {}) {}*{}={} in {} seconds", - cycle, ptxt[0], ptxt[1], decrypted, elapsed - ); - } - _ => unreachable!(), - }; - assert_eq!(decrypted, ptxt_result); - - // These should have been cached since the cycle is the same. - start = Instant::now(); - let result = match gate.get_gate_type() { - GateType::Add => { - gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) - } - GateType::Sub => { - gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle) - } - GateType::Mult => { - gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) - } - _ => unreachable!(), - }; - let elapsed_cached = start.elapsed().as_secs_f64(); - decrypted = result.decrypt(&client_key); - assert_eq!(decrypted, ptxt_result); - assert!(elapsed_cached < elapsed); - - cycle += 1; - - start = Instant::now(); - let (result, ptxt_result) = match gate.get_gate_type() { - GateType::Add => ( - gate.evaluate_encrypted_add_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), - PtxtType::U16(ptxt[1] + ptxt[2]), - ), - GateType::Sub => ( - gate.evaluate_encrypted_sub_block(&inputs_ctxt[2], &inputs_ctxt[1], cycle), - PtxtType::U16(ptxt[2] - ptxt[1]), - ), - GateType::Mult => ( - gate.evaluate_encrypted_mul_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), - PtxtType::U16(ptxt[1] * ptxt[2]), - ), - _ => unreachable!(), - }; - elapsed = start.elapsed().as_secs_f64(); - decrypted = result.decrypt(&client_key); - match gate.get_gate_type() { - GateType::Add => { - println!( - "Cycle {}) {}+{}={} in {} seconds", - cycle, ptxt[1], ptxt[2], decrypted, elapsed - ); - } - GateType::Sub => { - println!( - "Cycle {}) {}-{}={} in {} seconds", - cycle, ptxt[2], ptxt[1], decrypted, elapsed - ); - } - GateType::Mult => { - println!( - "Cycle {}) {}*{}={} in {} seconds", - cycle, ptxt[1], ptxt[2], decrypted, elapsed - ); - } - _ => unreachable!(), - }; - assert_eq!(decrypted, ptxt_result); - - cycle += 1; - - start = Instant::now(); - let (result, ptxt_result) = match gate.get_gate_type() { - GateType::Add => ( - gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), - PtxtType::U16(ptxt[2] + ptxt[3]), - ), - GateType::Sub => ( - gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle), - PtxtType::U16(ptxt[3] - ptxt[2]), - ), - GateType::Mult => ( - gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), - PtxtType::U16(ptxt[2] * ptxt[3]), - ), - _ => unreachable!(), - }; - elapsed = start.elapsed().as_secs_f64(); - decrypted = result.decrypt(&client_key); - match gate.get_gate_type() { - GateType::Add => { - println!( - "Cycle {}) {}+{}={} in {} seconds", - cycle, ptxt[2], ptxt[3], decrypted, elapsed - ); - } - GateType::Sub => { - println!( - "Cycle {}) {}-{}={} in {} seconds", - cycle, ptxt[3], ptxt[2], decrypted, elapsed - ); - } - GateType::Mult => { - println!( - "Cycle {}) {}*{}={} in {} seconds", - cycle, ptxt[2], ptxt[3], decrypted, elapsed - ); - } - _ => unreachable!(), - }; - assert_eq!(decrypted, ptxt_result); - } -} diff --git a/src/lib.rs b/src/lib.rs index 9b9e88f..0f28be7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,7 +75,7 @@ impl fmt::Display for PtxtType { } impl FheType { - fn decrypt(&self, client_key: &tfhe::ClientKey) -> PtxtType { + pub fn decrypt(&self, client_key: &tfhe::ClientKey) -> PtxtType { match self { FheType::U8(inner_value) => PtxtType::U8(inner_value.decrypt(client_key)), FheType::U16(inner_value) => PtxtType::U16(inner_value.decrypt(client_key)), @@ -87,6 +87,7 @@ impl FheType { } } +// TODO // arithmetic -i a 15 // boolean: 1) -i a[0] 1 -i a[1] 0 ... // boolean: 1) -i aeskey 0 ... @@ -112,7 +113,6 @@ pub fn get_input_wire_map( ); // [[wire1, value1], [wire2, value2], [wire3, value3]] - wire_inputs .iter() .map(|parts| { diff --git a/src/verilog_parser.rs b/src/verilog_parser.rs index 7ba003d..ba16def 100644 --- a/src/verilog_parser.rs +++ b/src/verilog_parser.rs @@ -229,7 +229,7 @@ pub fn read_verilog_file( || gate.get_gate_type() == GateType::Sub || gate.get_gate_type() == GateType::Mult || gate.get_gate_type() == GateType::Div - || gate.get_gate_type() == GateType::Shl + || gate.get_gate_type() == GateType::Shl || gate.get_gate_type() == GateType::Shr || gate.get_gate_type() == GateType::Copy { @@ -346,53 +346,3 @@ pub fn write_output_wires(file_name: Option, input_map: &HashMap { + assert_eq!(value, expected_val) + } + (PtxtType::U16(value), PtxtType::U16(expected_val)) => { + assert_eq!(value, expected_val) + } + (PtxtType::U32(value), PtxtType::U32(expected_val)) => { + assert_eq!(value, expected_val) + } + (PtxtType::U64(value), PtxtType::U64(expected_val)) => { + assert_eq!(value, expected_val) + } + (PtxtType::U128(value), PtxtType::U128(expected_val)) => { + assert_eq!(value, expected_val) + } + _ => panic!("Decrypted shouldn't be None"), + }; + } + + // Input set 2 + let input_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_2.inputs.csv", + datatype, + ); + let output_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_2.outputs.csv", + datatype, + ); + + let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); + enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 2, datatype); + + // Check that the evaluation was correct + for (wire_name, value) in output_wire_map { + match (enc_wire_map[&wire_name].decrypt(&client_key), value) { + (PtxtType::U8(val), PtxtType::U8(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U16(val), PtxtType::U16(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U32(val), PtxtType::U32(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U64(val), PtxtType::U64(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U128(val), PtxtType::U128(expected_val)) => { + assert_eq!(val, expected_val) + } + _ => panic!("Decrypted shouldn't be None"), + }; + } + + // Input set 3 + let input_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_3.inputs.csv", + datatype, + ); + let output_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_3.outputs.csv", + datatype, + ); + + let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); + enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 3, datatype); + + // Check that the evaluation was correct + for (wire_name, value) in output_wire_map { + match (enc_wire_map[&wire_name].decrypt(&client_key), value) { + (PtxtType::U8(val), PtxtType::U8(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U16(val), PtxtType::U16(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U32(val), PtxtType::U32(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U64(val), PtxtType::U64(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U128(val), PtxtType::U128(expected_val)) => { + assert_eq!(val, expected_val) + } + _ => panic!("Decrypted shouldn't be None"), + }; + } + + // Input set 4 + let input_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_4.inputs.csv", + datatype, + ); + let output_wire_map = verilog_parser::read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_4.outputs.csv", + datatype, + ); + + let mut enc_wire_map = EvalCircuit::encrypt_inputs(&mut circuit, &wire_set, &input_wire_map); + enc_wire_map = EvalCircuit::evaluate_encrypted(&mut circuit, &enc_wire_map, 4, datatype); + + // Check that the evaluation was correct + for (wire_name, value) in output_wire_map { + match (enc_wire_map[&wire_name].decrypt(&client_key), value) { + (PtxtType::U8(val), PtxtType::U8(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U16(val), PtxtType::U16(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U32(val), PtxtType::U32(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U64(val), PtxtType::U64(expected_val)) => { + assert_eq!(val, expected_val) + } + (PtxtType::U128(val), PtxtType::U128(expected_val)) => { + assert_eq!(val, expected_val) + } + _ => panic!("Decrypted shouldn't be None"), + }; + } +} diff --git a/tests/gates_test.rs b/tests/gates_test.rs new file mode 100644 index 0000000..43a4ccd --- /dev/null +++ b/tests/gates_test.rs @@ -0,0 +1,311 @@ +use helm::{ + gates::{Gate, GateType}, + FheType, PtxtType, +}; +use rand::Rng; +use tfhe::boolean::gen_keys; + +#[test] +fn encrypted_vs_plaintext_gates() { + let (client_key, server_key) = gen_keys(); + + let ptxts = vec![PtxtType::Bool(true), PtxtType::Bool(false)]; + let ctxts = vec![client_key.encrypt(true), client_key.encrypt(false)]; + let gates = vec![ + Gate::new( + String::from(""), + GateType::And, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Or, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Nor, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Xor, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Nand, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Not, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Xnor, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Mux, + vec![], + None, + "".to_string(), + 0, + ), + ]; + let mut rng = rand::thread_rng(); + let mut cycle = 1; + for mut gate in gates { + for i in 0..2 { + for j in 0..2 { + let mut inputs_ptxt = vec![ptxts[i], ptxts[j]]; + let mut inputs_ctxt = vec![ctxts[i].clone(), ctxts[j].clone()]; + if gate.get_gate_type() == GateType::Mux { + let select: bool = rng.gen(); + inputs_ptxt.push(PtxtType::Bool(select)); + inputs_ctxt.push(client_key.encrypt(select)); + } + let output_value_ptxt = gate.evaluate(&inputs_ptxt); + + let output_value_ctxt = gate.evaluate_encrypted(&server_key, &inputs_ctxt, cycle); + if gate.get_gate_type() == GateType::Lut { + continue; + } + + assert_eq!( + output_value_ptxt, + PtxtType::Bool(client_key.decrypt(&output_value_ctxt)) + ); + + cycle += 1; + } + } + } +} + +#[test] +fn caching_of_gate_evaluation() { + use std::time::Instant; + use tfhe::prelude::*; + use tfhe::set_server_key; + use tfhe::FheUint16; + use tfhe::{generate_keys, ConfigBuilder}; + + let config = ConfigBuilder::all_disabled() + .enable_custom_integers( + tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + None, + ) + .build(); + let (client_key, server_key) = generate_keys(config); // integer ctxt + set_server_key(server_key); + + let ptxt = vec![10, 20, 30, 40]; + let inputs_ctxt = vec![ + FheType::U16(FheUint16::try_encrypt(ptxt[0], &client_key).unwrap()), + FheType::U16(FheUint16::try_encrypt(ptxt[1], &client_key).unwrap()), + FheType::U16(FheUint16::try_encrypt(ptxt[2], &client_key).unwrap()), + FheType::U16(FheUint16::try_encrypt(ptxt[3], &client_key).unwrap()), + ]; + + let mut gates = vec![ + Gate::new( + String::from(""), + GateType::Add, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Sub, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Mult, + vec![], + None, + "".to_string(), + 0, + ), + ]; + + for gate in gates.iter_mut() { + let mut cycle = 1; + + let mut start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), + PtxtType::U16(ptxt[0] + ptxt[1]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[1], &inputs_ctxt[0], cycle), + PtxtType::U16(ptxt[1] - ptxt[0]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), + PtxtType::U16(ptxt[0] * ptxt[1]), + ), + _ => unreachable!(), + }; + let mut elapsed = start.elapsed().as_secs_f64(); + let mut decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[0], ptxt[1], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[1], ptxt[0], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[0], ptxt[1], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + + // These should have been cached since the cycle is the same. + start = Instant::now(); + let result = match gate.get_gate_type() { + GateType::Add => { + gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) + } + GateType::Sub => { + gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle) + } + GateType::Mult => { + gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) + } + _ => unreachable!(), + }; + let elapsed_cached = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + assert_eq!(decrypted, ptxt_result); + assert!(elapsed_cached < elapsed); + + cycle += 1; + + start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), + PtxtType::U16(ptxt[1] + ptxt[2]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[2], &inputs_ctxt[1], cycle), + PtxtType::U16(ptxt[2] - ptxt[1]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), + PtxtType::U16(ptxt[1] * ptxt[2]), + ), + _ => unreachable!(), + }; + elapsed = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[1], ptxt[2], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[2], ptxt[1], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[1], ptxt[2], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + + cycle += 1; + + start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), + PtxtType::U16(ptxt[2] + ptxt[3]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle), + PtxtType::U16(ptxt[3] - ptxt[2]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), + PtxtType::U16(ptxt[2] * ptxt[3]), + ), + _ => unreachable!(), + }; + elapsed = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[2], ptxt[3], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[3], ptxt[2], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[2], ptxt[3], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + } +} diff --git a/tests/verilog_parser_test.rs b/tests/verilog_parser_test.rs new file mode 100644 index 0000000..6c0d8a6 --- /dev/null +++ b/tests/verilog_parser_test.rs @@ -0,0 +1,51 @@ +use helm::verilog_parser::{read_input_wires, read_verilog_file}; + +#[test] +fn parse_two_bit_adder() { + let (gates, wire_set, inputs, _, _, _, _) = + read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v", false); + + assert_eq!(gates.len(), 10); + assert_eq!(wire_set.len(), 10); + assert_eq!(inputs.len(), 5); +} + +#[test] +fn input_wires_gates_parser() { + let (_, _, inputs, _, _, _, _) = + read_verilog_file("hdl-benchmarks/processed-netlists/2-bit-adder.v", false); + + let input_wires_map = + read_input_wires("hdl-benchmarks/test-cases/2-bit-adder.inputs.csv", "bool"); + + assert_eq!(input_wires_map.len(), inputs.len()); + for input_wire in inputs { + assert!(input_wires_map.contains_key(&input_wire)); + } +} + +#[test] +fn input_wires_arithmetic_parser() { + let (_, _, inputs, _, _, _, _) = read_verilog_file( + "hdl-benchmarks/processed-netlists/chi_squared_arith.v", + true, + ); + + let input_wires_map = read_input_wires( + "hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv", + "u32", + ); + + assert_eq!(input_wires_map.len(), inputs.len()); + for input_wire in inputs { + assert!(input_wires_map.contains_key(&input_wire)); + } +} + +// Check that it crashes if it contains both LUTs and arithmetic. +#[test] +#[should_panic(expected = "Can't mix LUTs with arithmetic operators!")] +fn invalid_arithmetic_with_luts_parser() { + let (_, _, _, _, _, _, _) = + read_verilog_file("hdl-benchmarks/processed-netlists/invalid.v", true); +}