diff --git a/src/bin/helm.rs b/src/bin/helm.rs index 67ff8c3..2b34e6d 100644 --- a/src/bin/helm.rs +++ b/src/bin/helm.rs @@ -164,17 +164,17 @@ fn main() { ); // Arithmetic mode - // let config = ConfigBuilder::all_disabled() - // .enable_custom_integers( - // tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - // None, - // ) - // .build(); - // let mut start = Instant::now(); - // let (client_key, server_key) = generate_keys(config); // integer ctxt - // set_server_key(server_key); - // let mut circuit = circuit::ArithCircuit::new(client_key, server_key, circuit_ptxt); - // println!("KeyGen done in {} seconds.", start.elapsed().as_secs_f64()); + let config = ConfigBuilder::all_disabled() + .enable_custom_integers( + tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + None, + ) + .build(); + let mut start = Instant::now(); + let (client_key, server_key) = generate_keys(config); // integer ctxt + set_server_key(server_key.clone()); + let mut circuit = circuit::ArithCircuit::new(client_key, server_key, circuit_ptxt); + println!("KeyGen done in {} seconds.", start.elapsed().as_secs_f64()); // Client encrypts their inputs // start = Instant::now(); diff --git a/src/circuit.rs b/src/circuit.rs index dfa61c0..1992e38 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -5,6 +5,7 @@ use std::{ collections::{hash_map::Entry, HashMap, HashSet}, sync::{Arc, RwLock}, vec, + default::Default, }; use termion::color; use tfhe::{ @@ -223,26 +224,32 @@ impl<'a> Circuit<'a> { self.ordered_gates.clear(); } - pub fn initialize_wire_map( + pub fn initialize_wire_map( &self, - wire_map_im: &HashMap, - user_inputs: &HashMap, - ) -> HashMap { - let mut wire_map = wire_map_im.clone(); + wire_map_im: &HashMap, + user_inputs: &HashMap, + ) -> HashMap { + let mut wire_map = HashMap::new(); + for (key, value) in wire_map_im.into_iter() { + wire_map.insert(key.clone(), value.clone()); + } for input_wire in self.input_wires { // if no inputs are provided, initialize it to false if user_inputs.is_empty() { - wire_map.insert(input_wire.to_string(), false); + wire_map.insert(input_wire.to_string(), T::default()); } else if !user_inputs.contains_key(input_wire) { panic!("\n Input wire \"{}\" not in input wires!", input_wire); } else { - wire_map.insert(input_wire.to_string(), user_inputs[input_wire]); + if let Some(user_value) = user_inputs.get(input_wire) { + wire_map.insert(input_wire.to_string(), user_value.clone()); + } else { + panic!("\n Input wire \"{}\" not in input wires!", input_wire); + } } } for wire in self.dff_outputs { - wire_map.insert(wire.to_string(), false); + wire_map.insert(wire.to_string(), T::default()); } - wire_map } @@ -684,6 +691,9 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { }; if gate.get_gate_type() == GateType::Add { gate.evaluate_encrypted_add_block_plain(&ct_op, ptxt_operand, cycle) + } else if gate.get_gate_type() == GateType::Sub { + gate.evaluate_encrypted_sub_block_plain(&ct_op, + ptxt_operand, cycle) } else { gate.evaluate_encrypted_mul_block_plain(&ct_op, ptxt_operand, cycle) } @@ -706,9 +716,14 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { .collect(); output_value = { if gate.get_gate_type() == GateType::Add { - gate.evaluate_encrypted_add_block(&input_values[0], &input_values[1], cycle) + gate.evaluate_encrypted_add_block(&input_values[0], + &input_values[1], cycle) + } else if gate.get_gate_type() == GateType::Sub { + gate.evaluate_encrypted_sub_block(&input_values[0], + &input_values[1], cycle) } else { - gate.evaluate_encrypted_mul_block(&input_values[0], &input_values[1], cycle) + gate.evaluate_encrypted_mul_block(&input_values[0], + &input_values[1], cycle) } }; } diff --git a/src/gates.rs b/src/gates.rs index 94f88d6..61f43bf 100644 --- a/src/gates.rs +++ b/src/gates.rs @@ -36,6 +36,7 @@ pub enum GateType { ConstZero, // zero(out); Mult, // mult ID(in0, in1, out); Add, // add ID(in0, in1, out); + Sub, // sub ID(in0, in1, out); } #[derive(Clone)] @@ -151,6 +152,8 @@ impl Gate { input_values.iter().product() } else if self.gate_type == GateType::Add { input_values.iter().sum() + } else if self.gate_type == GateType::Sub { + input_values.iter().fold(0, |diff, &x| diff - x) } else { 0 } @@ -186,6 +189,7 @@ impl Gate { } GateType::Mult => input_values[0], GateType::Add => input_values[0], + GateType::Sub => input_values[0], GateType::Mux => { let select = input_values[2]; (select && input_values[0]) || (!select && input_values[1]) @@ -223,6 +227,7 @@ impl Gate { GateType::Lut => panic!("Can't mix LUTs with Boolean gates!"), GateType::Add => panic!("Add gates can't be mixed with Boolean ops!"), GateType::Mult => panic!("Mult gates can't be mixed with Boolean ops!"), + GateType::Sub => panic!("Sub gates can't be mixed with Boolean ops!"), GateType::Mux => server_key.mux(&input_values[2], &input_values[0], &input_values[1]), GateType::Nand => server_key.nand(&input_values[0], &input_values[1]), GateType::Nor => server_key.nor(&input_values[0], &input_values[1]), @@ -310,6 +315,34 @@ impl Gate { ct1 + pt1 } + pub fn evaluate_encrypted_sub_block( + &mut self, + ct1: &FheUint32, + ct2: &FheUint32, + cycle: usize, + ) -> FheUint32 { + if let Some(encrypted_multibit_output) = self.encrypted_multibit_output.clone() { + if self.cycle == cycle { + return encrypted_multibit_output; + } + } + ct1 - ct2 + } + + pub fn evaluate_encrypted_sub_block_plain( + &mut self, + ct1: &FheUint32, + pt1: u32, + cycle: usize, + ) -> FheUint32 { + if let Some(encrypted_multibit_output) = self.encrypted_multibit_output.clone() { + if self.cycle == cycle { + return encrypted_multibit_output; + } + } + ct1 - pt1 + } + pub fn evaluate_encrypted_dff( &mut self, input_values: &[CiphertextBase],