Skip to content

Commit

Permalink
Remove unnecessary types
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Aug 28, 2023
1 parent e217f73 commit f91067a
Showing 1 changed file with 43 additions and 59 deletions.
102 changes: 43 additions & 59 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ use std::{
};
use termion::color;
use tfhe::{
boolean::ciphertext::Ciphertext as CtxtBool,
boolean::prelude::*,
integer::{
wopbs::WopbsKey as WopbsKeyInt, ClientKey as ClientKeyInt, ServerKey as ServerKeyInt,
},
prelude::*,
set_server_key, unset_server_key,
set_server_key,
shortint::{
ciphertext::Ciphertext, wopbs::WopbsKey as WopbsKeyShortInt,
ciphertext::Ciphertext as CtxtShortInt, wopbs::WopbsKey as WopbsKeyShortInt,
ClientKey as ClientKeyShortInt, ServerKey as ServerKeyShortInt,
},
unset_server_key, FheUint32,
};

#[cfg(test)]
Expand All @@ -33,23 +35,20 @@ use tfhe::shortint::parameters::{
#[cfg(test)]
use tfhe::{generate_keys, ConfigBuilder};

pub trait EvalCircuit<T> {
type CtxtType;
type PtxtValType;

pub trait EvalCircuit<P, C> {
fn encrypt_inputs(
&mut self,
wire_map_im: &HashMap<String, Self::PtxtValType>,
input_wire_map: &HashMap<String, Self::PtxtValType>,
) -> HashMap<String, T>;
wire_map_im: &HashMap<String, P>,
input_wire_map: &HashMap<String, P>,
) -> HashMap<String, C>;

fn evaluate_encrypted(
&mut self,
enc_wire_map: &HashMap<String, T>,
enc_wire_map: &HashMap<String, C>,
current_cycle: usize,
) -> HashMap<String, T>;
) -> HashMap<String, C>;

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, T>, verbose: bool);
fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, C>, verbose: bool);
}

pub struct Circuit<'a> {
Expand Down Expand Up @@ -374,15 +373,12 @@ impl<'a> HighPrecisionLutCircuit<'a> {
}
}

impl<'a> EvalCircuit<tfhe::boolean::ciphertext::Ciphertext> for GateCircuit<'a> {
type CtxtType = tfhe::boolean::ciphertext::Ciphertext;
type PtxtValType = bool;

impl<'a> EvalCircuit<bool, CtxtBool> for GateCircuit<'a> {
fn encrypt_inputs(
&mut self,
wire_map_im: &HashMap<String, bool>,
input_wire_map: &HashMap<String, bool>,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtBool> {
let mut enc_wire_map = HashMap::<String, _>::new();
for (wire, &value) in wire_map_im {
enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(value));
Expand All @@ -409,9 +405,9 @@ impl<'a> EvalCircuit<tfhe::boolean::ciphertext::Ciphertext> for GateCircuit<'a>

fn evaluate_encrypted(
&mut self,
enc_wire_map: &HashMap<String, Self::CtxtType>,
enc_wire_map: &HashMap<String, CtxtBool>,
cycle: usize,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtBool> {
// Make sure the sort circuit function has run.
assert!(self.circuit.gates.is_empty());
// Make sure the compute_levels function has run.
Expand All @@ -433,7 +429,7 @@ impl<'a> EvalCircuit<tfhe::boolean::ciphertext::Ciphertext> for GateCircuit<'a>
{
// Evaluate all the gates in the level in parallel
gates.par_iter_mut().for_each(|gate| {
let input_values: Vec<Self::CtxtType> = gate
let input_values: Vec<CtxtBool> = gate
.get_input_wires()
.iter()
.map(|input| {
Expand Down Expand Up @@ -464,7 +460,7 @@ impl<'a> EvalCircuit<tfhe::boolean::ciphertext::Ciphertext> for GateCircuit<'a>
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, Self::CtxtType>, verbose: bool) {
fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtBool>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
if i > 10 && !verbose {
println!(
Expand All @@ -484,15 +480,12 @@ impl<'a> EvalCircuit<tfhe::boolean::ciphertext::Ciphertext> for GateCircuit<'a>
}
}

impl<'a> EvalCircuit<Ciphertext> for LutCircuit<'a> {
type CtxtType = Ciphertext;
type PtxtValType = bool;

impl<'a> EvalCircuit<bool, CtxtShortInt> for LutCircuit<'a> {
fn encrypt_inputs(
&mut self,
wire_map_im: &HashMap<String, bool>,
input_wire_map: &HashMap<String, bool>,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtShortInt> {
let mut enc_wire_map = HashMap::<String, _>::new();
for (wire, &value) in wire_map_im {
enc_wire_map.insert(wire.to_string(), self.client_key.encrypt(value as u64));
Expand All @@ -519,9 +512,9 @@ impl<'a> EvalCircuit<Ciphertext> for LutCircuit<'a> {

fn evaluate_encrypted(
&mut self,
enc_wire_map: &HashMap<String, Self::CtxtType>,
enc_wire_map: &HashMap<String, CtxtShortInt>,
cycle: usize,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtShortInt> {
// Make sure the sort circuit function has run.
assert!(self.circuit.gates.is_empty());
// Make sure the compute_levels function has run.
Expand All @@ -543,7 +536,7 @@ impl<'a> EvalCircuit<Ciphertext> for LutCircuit<'a> {
{
// Evaluate all the gates in the level in parallel
gates.par_iter_mut().for_each(|gate| {
let input_values: Vec<Self::CtxtType> = gate
let input_values: Vec<CtxtShortInt> = gate
.get_input_wires()
.iter()
.map(|input| {
Expand Down Expand Up @@ -582,7 +575,7 @@ impl<'a> EvalCircuit<Ciphertext> for LutCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, Self::CtxtType>, verbose: bool) {
fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtShortInt>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
if i > 10 && !verbose {
println!(
Expand All @@ -602,21 +595,18 @@ impl<'a> EvalCircuit<Ciphertext> for LutCircuit<'a> {
}
}

impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
type CtxtType = tfhe::FheUint32;
type PtxtValType = u32;

impl<'a> EvalCircuit<u32, FheUint32> for ArithCircuit<'a> {
fn encrypt_inputs(
&mut self,
wire_map_im: &HashMap<String, u32>,
input_wire_map: &HashMap<String, u32>,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, FheUint32> {
let mut enc_wire_map = HashMap::<String, _>::new();
for (wire, &value) in wire_map_im {
if !is_numeric_string(wire) {
enc_wire_map.insert(
wire.to_string(),
Self::CtxtType::try_encrypt(value, &self.client_key).unwrap(),
FheUint32::try_encrypt(value, &self.client_key).unwrap(),
);
}
}
Expand All @@ -625,22 +615,21 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
if input_wire_map.is_empty() {
enc_wire_map.insert(
input_wire.to_string(),
Self::CtxtType::try_encrypt(0, &self.client_key).unwrap(),
FheUint32::try_encrypt(0, &self.client_key).unwrap(),
);
} else if !input_wire_map.contains_key(input_wire) {
panic!("\n Input wire \"{}\" not found in input wires!", input_wire);
} else {
enc_wire_map.insert(
input_wire.to_string(),
Self::CtxtType::try_encrypt(input_wire_map[input_wire], &self.client_key)
.unwrap(),
FheUint32::try_encrypt(input_wire_map[input_wire], &self.client_key).unwrap(),
);
}
}
for wire in self.circuit.dff_outputs {
enc_wire_map.insert(
wire.to_string(),
Self::CtxtType::try_encrypt(0, &self.client_key).unwrap(),
FheUint32::try_encrypt(0, &self.client_key).unwrap(),
);
}

Expand All @@ -649,9 +638,9 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {

fn evaluate_encrypted(
&mut self,
enc_wire_map: &HashMap<String, Self::CtxtType>,
enc_wire_map: &HashMap<String, FheUint32>,
cycle: usize,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, FheUint32> {
// Make sure the sort circuit function has run.
assert!(self.circuit.gates.is_empty());
// Make sure the compute_levels function has run.
Expand All @@ -675,7 +664,6 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
{
// Evaluate all the gates in the level in parallel
gates.par_iter_mut().for_each(|gate| {
// set_server_key(self.server_key.clone());
let mut is_ptxt_op = false;
// Identify if any of the input wires are constants
for in_wire in gate.get_input_wires().iter() {
Expand All @@ -686,10 +674,10 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
let output_value = {
if is_ptxt_op {
let mut ptxt_operand = 0;
let mut ctxt_operand: Option<Self::CtxtType> = None;
let mut ctxt_operand: Option<FheUint32> = None;
for in_wire in gate.get_input_wires().iter() {
if is_numeric_string(in_wire) {
ptxt_operand = in_wire.parse::<Self::PtxtValType>().unwrap_or(0);
ptxt_operand = in_wire.parse::<u32>().unwrap_or(0);
} else {
let index = match key_to_index.get(in_wire) {
Some(&index) => index,
Expand All @@ -713,7 +701,7 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
gate.evaluate_encrypted_mul_block_plain(&ct_op, ptxt_operand, cycle)
}
} else {
let input_values: Vec<Self::CtxtType> = gate
let input_values: Vec<FheUint32> = gate
.get_input_wires()
.iter()
.map(|input| {
Expand Down Expand Up @@ -760,15 +748,15 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
});
println!(" Evaluated gates in level [{}/{}]", level, total_levels);
}
rayon::broadcast(|_| unset_server_key());
rayon::broadcast(|_| unset_server_key());

key_to_index
.iter()
.map(|(&key, &index)| (key.to_string(), eval_values[index].read().unwrap().clone()))
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, Self::CtxtType>, verbose: bool) {
fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, FheUint32>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
if i > 10 && !verbose {
println!(
Expand All @@ -778,23 +766,19 @@ impl<'a> EvalCircuit<tfhe::FheUint32> for ArithCircuit<'a> {
);
break;
} else {
let decrypted: Self::PtxtValType =
enc_wire_map[output_wire].decrypt(&self.client_key);
let decrypted: u32 = enc_wire_map[output_wire].decrypt(&self.client_key);
println!(" {}: {}", output_wire, decrypted);
}
}
}
}

impl<'a> EvalCircuit<Ciphertext> for HighPrecisionLutCircuit<'a> {
type CtxtType = Ciphertext;
type PtxtValType = bool;

impl<'a> EvalCircuit<bool, CtxtShortInt> for HighPrecisionLutCircuit<'a> {
fn encrypt_inputs(
&mut self,
wire_map_im: &HashMap<String, bool>,
input_wire_map: &HashMap<String, bool>,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtShortInt> {
let mut enc_wire_map = HashMap::<String, _>::new();
for (wire, &value) in wire_map_im {
enc_wire_map.insert(
Expand Down Expand Up @@ -825,9 +809,9 @@ impl<'a> EvalCircuit<Ciphertext> for HighPrecisionLutCircuit<'a> {

fn evaluate_encrypted(
&mut self,
enc_wire_map: &HashMap<String, Self::CtxtType>,
enc_wire_map: &HashMap<String, CtxtShortInt>,
cycle: usize,
) -> HashMap<String, Self::CtxtType> {
) -> HashMap<String, CtxtShortInt> {
// Make sure the sort circuit function has run.
assert!(self.circuit.gates.is_empty());
// Make sure the compute_levels function has run.
Expand All @@ -849,7 +833,7 @@ impl<'a> EvalCircuit<Ciphertext> for HighPrecisionLutCircuit<'a> {
{
// Evaluate all the gates in the level in parallel
gates.par_iter_mut().for_each(|gate| {
let input_values: Vec<Self::CtxtType> = gate
let input_values: Vec<CtxtShortInt> = gate
.get_input_wires()
.iter()
.map(|input| {
Expand Down Expand Up @@ -888,7 +872,7 @@ impl<'a> EvalCircuit<Ciphertext> for HighPrecisionLutCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, Self::CtxtType>, verbose: bool) {
fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtShortInt>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
if i > 10 && !verbose {
println!(
Expand Down

0 comments on commit f91067a

Please sign in to comment.