Skip to content

Commit

Permalink
Merge branch 'main' of github.com:TrustworthyComputing/helm
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Sep 11, 2023
2 parents 6fc967b + a1de5f0 commit 42a570e
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/bin/helm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use helm::{ascii, circuit, circuit::EvalCircuit, verilog_parser};
use std::time::Instant;
use termion::color;
use tfhe::{
boolean::gen_keys, generate_keys, shortint::parameters::PARAM_MESSAGE_2_CARRY_0, ConfigBuilder,
boolean::gen_keys, generate_keys, shortint::parameters::PARAM_MESSAGE_2_CARRY_1_KS_PBS, ConfigBuilder,
};

fn main() {
Expand Down Expand Up @@ -169,7 +169,7 @@ fn main() {

// LUT mode
let mut start = Instant::now();
let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_2_CARRY_0); // single bit ctxt
let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_2_CARRY_1_KS_PBS); // single bit ctxt
let mut circuit = circuit::LutCircuit::new(client_key, server_key, circuit_ptxt);
println!("KeyGen done in {} seconds.", start.elapsed().as_secs_f64());

Expand Down
20 changes: 9 additions & 11 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ impl<'a> EvalCircuit<CtxtBool> for GateCircuit<'a> {

for input_wire in self.circuit.input_wires {
// if no inputs are provided, initialize it to false
if input_wire_map.is_empty() {
if input_wire_map.is_empty() || input_wire_map.contains_key("dummy") {
enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt(false));
} else if !input_wire_map.contains_key(input_wire) {
panic!("\n Input wire \"{}\" not in input wires!", input_wire);
Expand Down Expand Up @@ -530,13 +530,12 @@ impl<'a> EvalCircuit<CtxtBool> for GateCircuit<'a> {
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]);
decrypted_outputs.insert(output_wire.clone(), PtxtType::Bool(decrypted_value));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
for (i, (wire, val)) in decrypted_outputs.iter().sorted().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand Down Expand Up @@ -565,7 +564,7 @@ impl<'a> EvalCircuit<CtxtShortInt> for LutCircuit<'a> {
.collect::<HashMap<_, _>>();
for input_wire in self.circuit.input_wires {
// if no inputs are provided, initialize it to false
if input_wire_map.is_empty() {
if input_wire_map.is_empty() || input_wire_map.contains_key("dummy") {
enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt(0));
} else if !input_wire_map.contains_key(input_wire) {
panic!("\n Input wire \"{}\" not found in input wires!", input_wire);
Expand Down Expand Up @@ -658,13 +657,12 @@ impl<'a> EvalCircuit<CtxtShortInt> for LutCircuit<'a> {
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]);
decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted_value));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
for (i, (wire, val)) in decrypted_outputs.iter().sorted().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand Down Expand Up @@ -701,7 +699,7 @@ impl<'a> EvalCircuit<FheType> for ArithCircuit<'a> {
}
for input_wire in self.circuit.input_wires {
// if no inputs are provided, initialize it to false
if input_wire_map.is_empty() {
if input_wire_map.is_empty() || input_wire_map.contains_key("dummy") {
let encrypted_value = match ptxt_type {
"u8" => FheType::U8(FheUint8::try_encrypt(0, &self.client_key).unwrap()),
"u16" => FheType::U16(FheUint16::try_encrypt(0, &self.client_key).unwrap()),
Expand Down Expand Up @@ -940,7 +938,7 @@ impl<'a> EvalCircuit<FheType> for ArithCircuit<'a> {
decrypted_outputs.insert(output_wire.clone(), decrypted);
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
for (i, (wire, val)) in decrypted_outputs.iter().sorted().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand Down Expand Up @@ -969,7 +967,7 @@ impl<'a> EvalCircuit<CtxtShortInt> for HighPrecisionLutCircuit<'a> {
.collect::<HashMap<_, _>>();
for input_wire in self.circuit.input_wires {
// if no inputs are provided, initialize it to false
if input_wire_map.is_empty() {
if input_wire_map.is_empty() || input_wire_map.contains_key("dummy") {
enc_wire_map.insert(input_wire.to_string(), self.client_key.encrypt_one_block(0));
} else if !input_wire_map.contains_key(input_wire) {
panic!("\n Input wire \"{}\" not found in input wires!", input_wire);
Expand Down Expand Up @@ -1072,7 +1070,7 @@ impl<'a> EvalCircuit<CtxtShortInt> for HighPrecisionLutCircuit<'a> {
decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
for (i, (wire, val)) in decrypted_outputs.iter().sorted().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand All @@ -1087,4 +1085,4 @@ impl<'a> EvalCircuit<CtxtShortInt> for HighPrecisionLutCircuit<'a> {

decrypted_outputs
}
}
}
39 changes: 27 additions & 12 deletions src/gates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,24 +737,39 @@ fn eval_luts(x: u64, lut_table: &Vec<u64>) -> u64 {
lut_table[x as usize] & 1
}

fn eval_luts_bivariate_test(x: u64, y: u64, lut_table: &Vec<u64>) -> u64 {
lut_table[((x & 1) * 2 + (y & 1)) as usize]
}


pub fn lut(
sks: &ServerKeyShortInt,
lut_const: &Vec<u64>,
ctxts: &mut Vec<CiphertextBase>,
) -> CiphertextBase {
// Σ ctxts[i] * 2^i
let ctxts_len = (ctxts.len() - 1) as u8;
let ct_sum = ctxts
.iter_mut()
.enumerate()
.map(|(i, ct)| sks.smart_scalar_left_shift(ct, ctxts_len - i as u8))
.fold(sks.create_trivial(0), |acc, ct| sks.add(&acc, &ct));

// Generate LUT entries from lut_const
let lut = sks.generate_lookup_table(|x| eval_luts(x, lut_const));

// Eval PBS and return
sks.apply_lookup_table(&ct_sum, &lut)
println!("lut constant: {:?}", &lut_const);
if ctxts.len() == 2 {
let mut c0 = ctxts[0].clone();
let wrapped_f = |lhs: u64, rhs: u64| -> u64 { u64::from(eval_luts_bivariate_test(lhs as u64, rhs as u64, lut_const)) };
sks.smart_evaluate_bivariate_function(&mut c0, &mut ctxts[1], wrapped_f)
} else if ctxts.len() == 1 {
println!("lut_const: {:?}", &lut_const);
ctxts[0].clone()
} else {
println!("shouldn't get here!");
let ctxts_len: u8 = (ctxts.len() - 1) as u8;
let ct_sum = ctxts
.iter_mut()
.enumerate()
.map(|(i, ct)| sks.smart_scalar_left_shift(ct, ctxts_len - i as u8))
.fold(sks.create_trivial(0), |acc, ct| sks.add(&acc, &ct));
// Generate LUT entries from lut_const
let lut = sks.generate_lookup_table(|x| eval_luts(x, lut_const));

// Eval PBS and return
sks.apply_lookup_table(&ct_sum, &lut)
}
}

pub fn high_precision_lut(
Expand Down
16 changes: 14 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub enum PtxtError {
InvalidInput,
}

#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum PtxtType {
Bool(bool),
U8(u8),
Expand Down Expand Up @@ -148,7 +148,19 @@ pub fn get_input_wire_map(
color::Fg(color::Reset)
);

HashMap::new()
let mut input_wire_map = HashMap::new();
let ptxt = match arithmetic_type {
"bool" => PtxtType::Bool(false),
"u8" => PtxtType::U8(0),
"u16" => PtxtType::U16(0),
"u32" => PtxtType::U32(0),
"u64" => PtxtType::U64(0),
"u128" => PtxtType::U128(0),
_ => unreachable!(),
};
input_wire_map.insert("dummy".to_string(), ptxt);

input_wire_map
}
}

Expand Down
9 changes: 4 additions & 5 deletions src/verilog_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ fn extract_const_val(input_str: &str) -> &str {
&input_str[start_index + 1..end_index]
}

fn usize_to_bitvec(value: usize) -> Vec<u64> {
fn usize_to_bitvec(value: usize, lut_size: usize) -> Vec<u64> {
let mut bits: Vec<u64> = Vec::new();

for i in 0..64 {
for i in 0..lut_size {
let bit = ((value >> i) & 1) as u64;
bits.push(bit);
}
Expand Down Expand Up @@ -108,8 +108,7 @@ fn parse_gate(tokens: &[&str]) -> Gate {
Err(_) => panic!("Failed to parse integer"),
})
};

Some(usize_to_bitvec(lut_const_int.unwrap()))
Some(usize_to_bitvec(lut_const_int.unwrap(), 1 << input_wires.len()))
} else {
None
};
Expand Down
6 changes: 3 additions & 3 deletions tests/circuit_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use tfhe::{
shortint::{
parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_3_CARRY_0,
PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_1_KS_PBS,
},
wopbs::WopbsKey as WopbsKeyShortInt,
},
Expand Down Expand Up @@ -106,7 +106,7 @@ fn encrypted_two_bit_adder() {
fn encrypted_eight_bit_adder_lut() {
let datatype = "bool";
let (gates_set, wire_set, input_wires, _, _, _, _) = verilog_parser::read_verilog_file(
"hdl-benchmarks/processed-netlists/8-bit-adder-lut-3-1.v",
"hdl-benchmarks/processed-netlists/8-bit-adder-lut-2-1.v",
false,
);
let input_wire_map = verilog_parser::read_input_wires(
Expand All @@ -123,7 +123,7 @@ fn encrypted_eight_bit_adder_lut() {
let mut ptxt_wire_map = circuit_ptxt.initialize_wire_map(&wire_set, &input_wire_map, datatype);

// Encrypted single bit ctxt
let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_3_CARRY_0);
let (client_key, server_key) = tfhe::shortint::gen_keys(PARAM_MESSAGE_2_CARRY_1_KS_PBS);

// Plaintext
for input_wire in &input_wires {
Expand Down

0 comments on commit 42a570e

Please sign in to comment.