Skip to content

Commit

Permalink
Write outputs in a CSV file if it is passed as a cmd argument
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Aug 30, 2023
1 parent 68194be commit 8ef667c
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 38 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ cargo run --bin preprocessor --release \
--manifest-path=./hdl-benchmarks/Cargo.toml -- \
--input ./hdl-benchmarks/designs/chi_squared.v \
--output ./hdl-benchmarks/processed-netlists/chi_squared_arith.v
cargo run --bin helm --release -- --arithmetic u32 --input ./hdl-benchmarks/processed-netlists/chi_squared_arith.v --wires ./hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv
cargo run --bin helm --release -- --arithmetic u32 \
--input ./hdl-benchmarks/processed-netlists/chi_squared_arith.v \
--wires ./hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv
```

<p align="center">
Expand Down
40 changes: 31 additions & 9 deletions src/bin/helm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ use termion::color;
use tfhe::{boolean::prelude::*, shortint::parameters::PARAM_MESSAGE_4_CARRY_0};
use tfhe::{generate_keys, ConfigBuilder};

fn parse_args() -> (String, usize, bool, Option<String>, Option<String>) {
fn parse_args() -> (
String,
usize,
bool,
Option<String>,
Option<String>,
Option<String>,
) {
let matches = Command::new("HELM")
.about("HELM: Homomorphic Evaluation with EDA-driven Logic Minimization")
.arg(
Expand All @@ -26,6 +33,14 @@ fn parse_args() -> (String, usize, bool, Option<String>, Option<String>) {
.help("Input wire values")
.required(false),
)
.arg(
Arg::new("output-wires")
.long("output-wires")
.value_name("FILE")
.help("Path to a file to write the output wires")
.required(false)
.value_parser(clap::value_parser!(String)),
)
.arg(
Arg::new("cycles")
.long("cycles")
Expand All @@ -40,6 +55,7 @@ fn parse_args() -> (String, usize, bool, Option<String>, Option<String>) {
.short('v')
.long("verbose")
.help("Turn verbose printing on")
.required(false)
.action(ArgAction::SetTrue),
)
.arg(
Expand All @@ -62,7 +78,8 @@ fn parse_args() -> (String, usize, bool, Option<String>, Option<String>) {
.expect("Verilog input file is required");
let num_cycles = *matches.get_one::<usize>("cycles").expect("required");
let verbose = matches.get_flag("verbose");
let wires_file = matches.get_one::<String>("wires").cloned();
let input_wires = matches.get_one::<String>("wires").cloned();
let output_wires = matches.get_one::<String>("output-wires").cloned();
let arithmetic = matches.get_one::<String>("arithmetic").cloned();

// TODO: Add support for this.
Expand All @@ -76,13 +93,15 @@ fn parse_args() -> (String, usize, bool, Option<String>, Option<String>) {
num_cycles,
verbose,
arithmetic,
wires_file,
input_wires,
output_wires,
)
}

fn main() {
ascii::print_art();
let (file_name, num_cycles, verbose, arithmetic, wire_file) = parse_args();
let (file_name, num_cycles, verbose, arithmetic, inputs_filename, outputs_filename) =
parse_args();
if let Some(arithmetic_type) = arithmetic {
println!(
"{} -- Arithmetic mode with {} -- {}",
Expand All @@ -96,7 +115,7 @@ fn main() {
_ => unreachable!(),
}

let input_wire_map = get_input_wire_map(wire_file, arithmetic_type);
let input_wire_map = get_input_wire_map(inputs_filename, arithmetic_type);
let (gates_set, wire_map_im, input_wires, output_wires, dff_outputs, _, _) =
verilog_parser::read_verilog_file(&file_name, true, arithmetic_type);
let mut circuit_ptxt =
Expand Down Expand Up @@ -141,14 +160,15 @@ fn main() {
// Client decrypts the output of the circuit
start = Instant::now();
println!("Encrypted Evaluation:");
EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs);
println!(
"Decryption done in {} seconds.",
start.elapsed().as_secs_f64()
);
} else {
let arithmetic_type = "bool";
let input_wire_map = get_input_wire_map(wire_file, arithmetic_type);
let input_wire_map = get_input_wire_map(inputs_filename, arithmetic_type);
let (gates_set, wire_map_im, input_wires, output_wires, dff_outputs, has_luts, _) =
verilog_parser::read_verilog_file(&file_name, false, arithmetic_type);
let is_sequential = dff_outputs.len() > 1;
Expand Down Expand Up @@ -226,7 +246,8 @@ fn main() {
// Client decrypts the output of the circuit
start = Instant::now();
println!("Encrypted Evaluation:");
EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs);
println!(
"Decryption done in {} seconds.",
start.elapsed().as_secs_f64()
Expand Down Expand Up @@ -271,7 +292,8 @@ fn main() {
// Client decrypts the output of the circuit
start = Instant::now();
println!("Encrypted Evaluation:");
EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose);
verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs);
println!(
"Decryption done in {} seconds.",
start.elapsed().as_secs_f64()
Expand Down
98 changes: 71 additions & 27 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ pub trait EvalCircuit<C> {
ptxt_type: &str,
) -> HashMap<String, C>;

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, C>, verbose: bool);
fn decrypt_outputs(
&self,
enc_wire_map: &HashMap<String, C>,
verbose: bool,
) -> HashMap<String, PtxtType>;
}

pub struct Circuit<'a> {
Expand Down Expand Up @@ -527,8 +531,19 @@ impl<'a> EvalCircuit<CtxtBool> for GateCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtBool>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
fn decrypt_outputs(
&self,
enc_wire_map: &HashMap<String, CtxtBool>,
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]);
decrypted_outputs.insert(output_wire.clone(), PtxtType::Bool(decrypted_value));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand All @@ -537,13 +552,11 @@ impl<'a> EvalCircuit<CtxtBool> for GateCircuit<'a> {
);
break;
} else {
println!(
" {}: {}",
output_wire,
self.client_key.decrypt(&enc_wire_map[output_wire])
);
println!(" {}: {}", wire, val);
}
}

decrypted_outputs
}
}

Expand Down Expand Up @@ -651,8 +664,19 @@ impl<'a> EvalCircuit<CtxtShortInt> for LutCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtShortInt>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
fn decrypt_outputs(
&self,
enc_wire_map: &HashMap<String, CtxtShortInt>,
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]);
decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted_value));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand All @@ -661,13 +685,11 @@ impl<'a> EvalCircuit<CtxtShortInt> for LutCircuit<'a> {
);
break;
} else {
println!(
" {}: {}",
output_wire,
self.client_key.decrypt(&enc_wire_map[output_wire])
);
println!(" {}: {}", wire, val);
}
}

decrypted_outputs
}
}

Expand Down Expand Up @@ -895,8 +917,19 @@ impl<'a> EvalCircuit<FheType> for ArithCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, FheType>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
fn decrypt_outputs(
&self,
enc_wire_map: &HashMap<String, FheType>,
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted = enc_wire_map[output_wire].decrypt(&self.client_key);
decrypted_outputs.insert(output_wire.clone(), decrypted);
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand All @@ -905,10 +938,11 @@ impl<'a> EvalCircuit<FheType> for ArithCircuit<'a> {
);
break;
} else {
let decrypted = enc_wire_map[output_wire].decrypt(&self.client_key);
println!(" {}: {}", output_wire, decrypted);
println!(" {}: {}", wire, val);
}
}

decrypted_outputs
}
}

Expand Down Expand Up @@ -1021,8 +1055,21 @@ impl<'a> EvalCircuit<CtxtShortInt> for HighPrecisionLutCircuit<'a> {
.collect()
}

fn decrypt_outputs(&self, enc_wire_map: &HashMap<String, CtxtShortInt>, verbose: bool) {
for (i, output_wire) in self.circuit.output_wires.iter().enumerate() {
fn decrypt_outputs(
&self,
enc_wire_map: &HashMap<String, CtxtShortInt>,
verbose: bool,
) -> HashMap<String, PtxtType> {
let mut decrypted_outputs = HashMap::new();

for output_wire in self.circuit.output_wires {
let decrypted = self
.client_key
.decrypt_one_block(&enc_wire_map[output_wire]);
decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted));
}

for (i, (wire, val)) in decrypted_outputs.iter().enumerate() {
if i > 10 && !verbose {
println!(
"{}[!]{} More than ten output_wires, pass `--verbose` to see output.",
Expand All @@ -1031,14 +1078,11 @@ impl<'a> EvalCircuit<CtxtShortInt> for HighPrecisionLutCircuit<'a> {
);
break;
} else {
println!(
" {}: {}",
output_wire,
self.client_key
.decrypt_one_block(&enc_wire_map[output_wire])
);
println!(" {}: {}", wire, val);
}
}

decrypted_outputs
}
}

Expand Down
34 changes: 33 additions & 1 deletion src/verilog_parser.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use csv::Reader;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::io::{BufRead, BufReader, BufWriter, Write};
use termion::color;

use crate::gates::{Gate, GateType};
Expand Down Expand Up @@ -322,6 +322,38 @@ pub fn read_input_wires(file_name: &str, ptxt_type: &str) -> HashMap<String, Ptx
input_map
}

pub fn write_output_wires(file_name: Option<String>, input_map: &HashMap<String, PtxtType>) {
if let Some(file_name) = file_name {
let file = File::create(&file_name).expect("Failed to create CSV file");
let mut writer = BufWriter::new(file);

for (input_wire, ptxt_type) in input_map.iter() {
match ptxt_type {
PtxtType::Bool(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::U8(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::U16(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::U32(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::U64(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::U128(value) => {
writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record");
}
PtxtType::None => unreachable!(),
}
}
println!("Decrypted outputs written to {}", file_name);
}
}

#[test]
fn test_parser() {
let (gates, wire_map, inputs, _, _, _, _) = read_verilog_file(
Expand Down

0 comments on commit 8ef667c

Please sign in to comment.