diff --git a/src/bin/helm.rs b/src/bin/helm.rs index d2483e6..61e5e03 100644 --- a/src/bin/helm.rs +++ b/src/bin/helm.rs @@ -102,16 +102,8 @@ fn main() { let (gates_set, wire_map_im, input_wires, output_wires, dff_outputs, _, _) = verilog_parser::read_verilog_file(&file_name, true, arithmetic_type); let mut circuit_ptxt = - circuit::Circuit::new(gates_set.clone(), &input_wires, &output_wires, &dff_outputs); + circuit::Circuit::new(gates_set, &input_wires, &output_wires, &dff_outputs); - // TODO(@cgouert): move this check in the parser (and the same for gates/luts) - if gates_set.is_empty() { - panic!( - "{}[!]{} Parser error, no arithmetic gates detected.", - color::Fg(color::LightRed), - color::Fg(color::Reset) - ); - } circuit_ptxt.sort_circuit(); circuit_ptxt.compute_levels(); #[cfg(debug_assertions)] @@ -163,7 +155,7 @@ fn main() { verilog_parser::read_verilog_file(&file_name, false, arithmetic_type); let is_sequential = dff_outputs.len() > 1; let mut circuit_ptxt = - circuit::Circuit::new(gates_set.clone(), &input_wires, &output_wires, &dff_outputs); + circuit::Circuit::new(gates_set, &input_wires, &output_wires, &dff_outputs); if num_cycles > 1 && !is_sequential { panic!( "{}[!]{} Cannot run combinational circuit for more than one cycles.", @@ -171,14 +163,6 @@ fn main() { color::Fg(color::Reset) ); } - if gates_set.is_empty() { - panic!( - "{}[!]{} Parser error, no gates detected. Make sure to use the \ - 'no-expr' flag in Yosys.", - color::Fg(color::LightRed), - color::Fg(color::Reset) - ); - } circuit_ptxt.sort_circuit(); circuit_ptxt.compute_levels(); #[cfg(debug_assertions)] diff --git a/src/gates.rs b/src/gates.rs index 48f872f..7ce83f8 100644 --- a/src/gates.rs +++ b/src/gates.rs @@ -306,13 +306,14 @@ impl Gate { ct2: &FheType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, ct2) { + self.encrypted_multibit_output = match (ct1, ct2) { (FheType::Uint32(ct1_value), FheType::Uint32(ct2_value)) => { FheType::Uint32(ct1_value * ct2_value) } @@ -320,7 +321,10 @@ impl Gate { FheType::Uint16(ct1_value * ct2_value) } _ => panic!("evaluate_encrypted_mul_block"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_mul_block_plain( @@ -329,13 +333,14 @@ impl Gate { pt1: PtxtType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, pt1) { + self.encrypted_multibit_output = match (ct1, pt1) { (FheType::Uint32(ct1_value), PtxtType::Uint32(pt1_value)) => { FheType::Uint32(ct1_value * pt1_value) } @@ -343,7 +348,10 @@ impl Gate { FheType::Uint16(ct1_value * pt1_value) } _ => panic!("evaluate_encrypted_mul_block_plain"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_add_block( @@ -352,13 +360,14 @@ impl Gate { ct2: &FheType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, ct2) { + self.encrypted_multibit_output = match (ct1, ct2) { (FheType::Uint32(ct1_value), FheType::Uint32(ct2_value)) => { FheType::Uint32(ct1_value + ct2_value) } @@ -366,7 +375,10 @@ impl Gate { FheType::Uint16(ct1_value + ct2_value) } _ => panic!("evaluate_encrypted_add_block"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_add_block_plain( @@ -375,13 +387,14 @@ impl Gate { pt1: PtxtType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, pt1) { + self.encrypted_multibit_output = match (ct1, pt1) { (FheType::Uint32(ct1_value), PtxtType::Uint32(pt1_value)) => { FheType::Uint32(ct1_value + pt1_value) } @@ -389,7 +402,10 @@ impl Gate { FheType::Uint16(ct1_value + pt1_value) } _ => panic!("evaluate_encrypted_add_block_plain"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_sub_block( @@ -398,13 +414,14 @@ impl Gate { ct2: &FheType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, ct2) { + self.encrypted_multibit_output = match (ct1, ct2) { (FheType::Uint32(ct1_value), FheType::Uint32(ct2_value)) => { FheType::Uint32(ct1_value - ct2_value) } @@ -412,7 +429,10 @@ impl Gate { FheType::Uint16(ct1_value - ct2_value) } _ => panic!("evaluate_encrypted_sub_block"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_sub_block_plain( @@ -421,13 +441,14 @@ impl Gate { pt1: PtxtType, cycle: usize, ) -> FheType { - if let FheType::None = self.encrypted_multibit_output { - if self.cycle == cycle { - return self.encrypted_multibit_output.clone(); + if self.cycle == cycle { + match self.encrypted_multibit_output { + FheType::None => (), + _ => return self.encrypted_multibit_output.clone(), } } - match (ct1, pt1) { + self.encrypted_multibit_output = match (ct1, pt1) { (FheType::Uint32(ct1_value), PtxtType::Uint32(pt1_value)) => { FheType::Uint32(ct1_value - pt1_value) } @@ -435,7 +456,10 @@ impl Gate { FheType::Uint16(ct1_value - pt1_value) } _ => panic!("evaluate_encrypted_sub_block_plain"), - } + }; + + self.cycle = cycle; + self.encrypted_multibit_output.clone() } pub fn evaluate_encrypted_dff( @@ -448,8 +472,10 @@ impl Gate { return encrypted_lut_output; } } + let out = input_values[0].clone(); self.encrypted_lut_output = Some(out.clone()); + out } @@ -579,3 +605,206 @@ where } lut } + +#[test] +fn test_caching_of_gate_evaluation() { + use std::time::Instant; + use tfhe::prelude::*; + use tfhe::FheUint16; + use tfhe::set_server_key; + use tfhe::{generate_keys, ConfigBuilder}; + + let config = ConfigBuilder::all_disabled() + .enable_custom_integers( + tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + None, + ) + .build(); + let (client_key, server_key) = generate_keys(config); // integer ctxt + set_server_key(server_key); + + let ptxt = vec![10, 20, 30, 40]; + let inputs_ctxt = vec![ + FheType::Uint16(FheUint16::try_encrypt(ptxt[0], &client_key).unwrap()), + FheType::Uint16(FheUint16::try_encrypt(ptxt[1], &client_key).unwrap()), + FheType::Uint16(FheUint16::try_encrypt(ptxt[2], &client_key).unwrap()), + FheType::Uint16(FheUint16::try_encrypt(ptxt[3], &client_key).unwrap()), + ]; + + let mut gates = vec![ + Gate::new( + String::from(""), + GateType::Add, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Sub, + vec![], + None, + "".to_string(), + 0, + ), + Gate::new( + String::from(""), + GateType::Mult, + vec![], + None, + "".to_string(), + 0, + ), + ]; + + for gate in gates.iter_mut() { + let mut cycle = 1; + + let mut start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), + PtxtType::Uint16(ptxt[0] + ptxt[1]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[1], &inputs_ctxt[0], cycle), + PtxtType::Uint16(ptxt[1] - ptxt[0]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[0], &inputs_ctxt[1], cycle), + PtxtType::Uint16(ptxt[0] * ptxt[1]), + ), + _ => unreachable!(), + }; + let mut elapsed = start.elapsed().as_secs_f64(); + let mut decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[0], ptxt[1], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[1], ptxt[0], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[0], ptxt[1], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + + // These should have been cached since the cycle is the same. + start = Instant::now(); + let result = match gate.get_gate_type() { + GateType::Add => { + gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) + } + GateType::Sub => { + gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle) + } + GateType::Mult => { + gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle) + } + _ => unreachable!(), + }; + let elapsed_cached = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + assert_eq!(decrypted, ptxt_result); + assert!(elapsed_cached < elapsed); + + cycle += 1; + + start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), + PtxtType::Uint16(ptxt[1] + ptxt[2]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[2], &inputs_ctxt[1], cycle), + PtxtType::Uint16(ptxt[2] - ptxt[1]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[1], &inputs_ctxt[2], cycle), + PtxtType::Uint16(ptxt[1] * ptxt[2]), + ), + _ => unreachable!(), + }; + elapsed = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[1], ptxt[2], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[2], ptxt[1], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[1], ptxt[2], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + + cycle += 1; + + start = Instant::now(); + let (result, ptxt_result) = match gate.get_gate_type() { + GateType::Add => ( + gate.evaluate_encrypted_add_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), + PtxtType::Uint16(ptxt[2] + ptxt[3]), + ), + GateType::Sub => ( + gate.evaluate_encrypted_sub_block(&inputs_ctxt[3], &inputs_ctxt[2], cycle), + PtxtType::Uint16(ptxt[3] - ptxt[2]), + ), + GateType::Mult => ( + gate.evaluate_encrypted_mul_block(&inputs_ctxt[2], &inputs_ctxt[3], cycle), + PtxtType::Uint16(ptxt[2] * ptxt[3]), + ), + _ => unreachable!(), + }; + elapsed = start.elapsed().as_secs_f64(); + decrypted = result.decrypt(&client_key); + match gate.get_gate_type() { + GateType::Add => { + println!( + "Cycle {}) {}+{}={} in {} seconds", + cycle, ptxt[2], ptxt[3], decrypted, elapsed + ); + } + GateType::Sub => { + println!( + "Cycle {}) {}-{}={} in {} seconds", + cycle, ptxt[3], ptxt[2], decrypted, elapsed + ); + } + GateType::Mult => { + println!( + "Cycle {}) {}*{}={} in {} seconds", + cycle, ptxt[2], ptxt[3], decrypted, elapsed + ); + } + _ => unreachable!(), + }; + assert_eq!(decrypted, ptxt_result); + } +} diff --git a/src/verilog_parser.rs b/src/verilog_parser.rs index f0565f9..5707849 100644 --- a/src/verilog_parser.rs +++ b/src/verilog_parser.rs @@ -2,6 +2,7 @@ use csv::Reader; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{BufRead, BufReader}; +use termion::color; use crate::gates::{Gate, GateType}; use crate::PtxtType; @@ -240,6 +241,21 @@ pub fn read_verilog_file( } } + if has_arith && gates.is_empty() { + panic!( + "{}[!]{} Parser error, no arithmetic gates detected.", + color::Fg(color::LightRed), + color::Fg(color::Reset) + ); + } else if gates.is_empty() { + panic!( + "{}[!]{} Parser error, no gates detected. Make sure to use the \ + 'no-expr' flag in Yosys.", + color::Fg(color::LightRed), + color::Fg(color::Reset) + ); + } + if has_arith && has_luts { panic!("Can't mix LUTs with arithmetic operators!"); }