From 8ef667ca14e20b5845d27d438e1bb0583a0da4f3 Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Wed, 30 Aug 2023 15:14:32 -0400 Subject: [PATCH] Write outputs in a CSV file if it is passed as a cmd argument --- README.md | 4 +- src/bin/helm.rs | 40 ++++++++++++++---- src/circuit.rs | 98 +++++++++++++++++++++++++++++++------------ src/verilog_parser.rs | 34 ++++++++++++++- 4 files changed, 138 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 42206d8..6d2c5a5 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,9 @@ cargo run --bin preprocessor --release \ --manifest-path=./hdl-benchmarks/Cargo.toml -- \ --input ./hdl-benchmarks/designs/chi_squared.v \ --output ./hdl-benchmarks/processed-netlists/chi_squared_arith.v -cargo run --bin helm --release -- --arithmetic u32 --input ./hdl-benchmarks/processed-netlists/chi_squared_arith.v --wires ./hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv +cargo run --bin helm --release -- --arithmetic u32 \ + --input ./hdl-benchmarks/processed-netlists/chi_squared_arith.v \ + --wires ./hdl-benchmarks/test-cases/chi_squared_arith_1.inputs.csv ```

diff --git a/src/bin/helm.rs b/src/bin/helm.rs index 3b0df70..dade5e9 100644 --- a/src/bin/helm.rs +++ b/src/bin/helm.rs @@ -9,7 +9,14 @@ use termion::color; use tfhe::{boolean::prelude::*, shortint::parameters::PARAM_MESSAGE_4_CARRY_0}; use tfhe::{generate_keys, ConfigBuilder}; -fn parse_args() -> (String, usize, bool, Option, Option) { +fn parse_args() -> ( + String, + usize, + bool, + Option, + Option, + Option, +) { let matches = Command::new("HELM") .about("HELM: Homomorphic Evaluation with EDA-driven Logic Minimization") .arg( @@ -26,6 +33,14 @@ fn parse_args() -> (String, usize, bool, Option, Option) { .help("Input wire values") .required(false), ) + .arg( + Arg::new("output-wires") + .long("output-wires") + .value_name("FILE") + .help("Path to a file to write the output wires") + .required(false) + .value_parser(clap::value_parser!(String)), + ) .arg( Arg::new("cycles") .long("cycles") @@ -40,6 +55,7 @@ fn parse_args() -> (String, usize, bool, Option, Option) { .short('v') .long("verbose") .help("Turn verbose printing on") + .required(false) .action(ArgAction::SetTrue), ) .arg( @@ -62,7 +78,8 @@ fn parse_args() -> (String, usize, bool, Option, Option) { .expect("Verilog input file is required"); let num_cycles = *matches.get_one::("cycles").expect("required"); let verbose = matches.get_flag("verbose"); - let wires_file = matches.get_one::("wires").cloned(); + let input_wires = matches.get_one::("wires").cloned(); + let output_wires = matches.get_one::("output-wires").cloned(); let arithmetic = matches.get_one::("arithmetic").cloned(); // TODO: Add support for this. @@ -76,13 +93,15 @@ fn parse_args() -> (String, usize, bool, Option, Option) { num_cycles, verbose, arithmetic, - wires_file, + input_wires, + output_wires, ) } fn main() { ascii::print_art(); - let (file_name, num_cycles, verbose, arithmetic, wire_file) = parse_args(); + let (file_name, num_cycles, verbose, arithmetic, inputs_filename, outputs_filename) = + parse_args(); if let Some(arithmetic_type) = arithmetic { println!( "{} -- Arithmetic mode with {} -- {}", @@ -96,7 +115,7 @@ fn main() { _ => unreachable!(), } - let input_wire_map = get_input_wire_map(wire_file, arithmetic_type); + let input_wire_map = get_input_wire_map(inputs_filename, arithmetic_type); let (gates_set, wire_map_im, input_wires, output_wires, dff_outputs, _, _) = verilog_parser::read_verilog_file(&file_name, true, arithmetic_type); let mut circuit_ptxt = @@ -141,14 +160,15 @@ fn main() { // Client decrypts the output of the circuit start = Instant::now(); println!("Encrypted Evaluation:"); - EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs); println!( "Decryption done in {} seconds.", start.elapsed().as_secs_f64() ); } else { let arithmetic_type = "bool"; - let input_wire_map = get_input_wire_map(wire_file, arithmetic_type); + let input_wire_map = get_input_wire_map(inputs_filename, arithmetic_type); let (gates_set, wire_map_im, input_wires, output_wires, dff_outputs, has_luts, _) = verilog_parser::read_verilog_file(&file_name, false, arithmetic_type); let is_sequential = dff_outputs.len() > 1; @@ -226,7 +246,8 @@ fn main() { // Client decrypts the output of the circuit start = Instant::now(); println!("Encrypted Evaluation:"); - EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs); println!( "Decryption done in {} seconds.", start.elapsed().as_secs_f64() @@ -271,7 +292,8 @@ fn main() { // Client decrypts the output of the circuit start = Instant::now(); println!("Encrypted Evaluation:"); - EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + let decrypted_outputs = EvalCircuit::decrypt_outputs(&circuit, &enc_wire_map, verbose); + verilog_parser::write_output_wires(outputs_filename, &decrypted_outputs); println!( "Decryption done in {} seconds.", start.elapsed().as_secs_f64() diff --git a/src/circuit.rs b/src/circuit.rs index 6469763..09fc174 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -50,7 +50,11 @@ pub trait EvalCircuit { ptxt_type: &str, ) -> HashMap; - fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool); + fn decrypt_outputs( + &self, + enc_wire_map: &HashMap, + verbose: bool, + ) -> HashMap; } pub struct Circuit<'a> { @@ -527,8 +531,19 @@ impl<'a> EvalCircuit for GateCircuit<'a> { .collect() } - fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { - for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { + fn decrypt_outputs( + &self, + enc_wire_map: &HashMap, + verbose: bool, + ) -> HashMap { + let mut decrypted_outputs = HashMap::new(); + + for output_wire in self.circuit.output_wires { + let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]); + decrypted_outputs.insert(output_wire.clone(), PtxtType::Bool(decrypted_value)); + } + + for (i, (wire, val)) in decrypted_outputs.iter().enumerate() { if i > 10 && !verbose { println!( "{}[!]{} More than ten output_wires, pass `--verbose` to see output.", @@ -537,13 +552,11 @@ impl<'a> EvalCircuit for GateCircuit<'a> { ); break; } else { - println!( - " {}: {}", - output_wire, - self.client_key.decrypt(&enc_wire_map[output_wire]) - ); + println!(" {}: {}", wire, val); } } + + decrypted_outputs } } @@ -651,8 +664,19 @@ impl<'a> EvalCircuit for LutCircuit<'a> { .collect() } - fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { - for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { + fn decrypt_outputs( + &self, + enc_wire_map: &HashMap, + verbose: bool, + ) -> HashMap { + let mut decrypted_outputs = HashMap::new(); + + for output_wire in self.circuit.output_wires { + let decrypted_value = self.client_key.decrypt(&enc_wire_map[output_wire]); + decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted_value)); + } + + for (i, (wire, val)) in decrypted_outputs.iter().enumerate() { if i > 10 && !verbose { println!( "{}[!]{} More than ten output_wires, pass `--verbose` to see output.", @@ -661,13 +685,11 @@ impl<'a> EvalCircuit for LutCircuit<'a> { ); break; } else { - println!( - " {}: {}", - output_wire, - self.client_key.decrypt(&enc_wire_map[output_wire]) - ); + println!(" {}: {}", wire, val); } } + + decrypted_outputs } } @@ -895,8 +917,19 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { .collect() } - fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { - for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { + fn decrypt_outputs( + &self, + enc_wire_map: &HashMap, + verbose: bool, + ) -> HashMap { + let mut decrypted_outputs = HashMap::new(); + + for output_wire in self.circuit.output_wires { + let decrypted = enc_wire_map[output_wire].decrypt(&self.client_key); + decrypted_outputs.insert(output_wire.clone(), decrypted); + } + + for (i, (wire, val)) in decrypted_outputs.iter().enumerate() { if i > 10 && !verbose { println!( "{}[!]{} More than ten output_wires, pass `--verbose` to see output.", @@ -905,10 +938,11 @@ impl<'a> EvalCircuit for ArithCircuit<'a> { ); break; } else { - let decrypted = enc_wire_map[output_wire].decrypt(&self.client_key); - println!(" {}: {}", output_wire, decrypted); + println!(" {}: {}", wire, val); } } + + decrypted_outputs } } @@ -1021,8 +1055,21 @@ impl<'a> EvalCircuit for HighPrecisionLutCircuit<'a> { .collect() } - fn decrypt_outputs(&self, enc_wire_map: &HashMap, verbose: bool) { - for (i, output_wire) in self.circuit.output_wires.iter().enumerate() { + fn decrypt_outputs( + &self, + enc_wire_map: &HashMap, + verbose: bool, + ) -> HashMap { + let mut decrypted_outputs = HashMap::new(); + + for output_wire in self.circuit.output_wires { + let decrypted = self + .client_key + .decrypt_one_block(&enc_wire_map[output_wire]); + decrypted_outputs.insert(output_wire.clone(), PtxtType::U64(decrypted)); + } + + for (i, (wire, val)) in decrypted_outputs.iter().enumerate() { if i > 10 && !verbose { println!( "{}[!]{} More than ten output_wires, pass `--verbose` to see output.", @@ -1031,14 +1078,11 @@ impl<'a> EvalCircuit for HighPrecisionLutCircuit<'a> { ); break; } else { - println!( - " {}: {}", - output_wire, - self.client_key - .decrypt_one_block(&enc_wire_map[output_wire]) - ); + println!(" {}: {}", wire, val); } } + + decrypted_outputs } } diff --git a/src/verilog_parser.rs b/src/verilog_parser.rs index fac1141..67994b6 100644 --- a/src/verilog_parser.rs +++ b/src/verilog_parser.rs @@ -1,7 +1,7 @@ use csv::Reader; use std::collections::{HashMap, HashSet}; use std::fs::File; -use std::io::{BufRead, BufReader}; +use std::io::{BufRead, BufReader, BufWriter, Write}; use termion::color; use crate::gates::{Gate, GateType}; @@ -322,6 +322,38 @@ pub fn read_input_wires(file_name: &str, ptxt_type: &str) -> HashMap, input_map: &HashMap) { + if let Some(file_name) = file_name { + let file = File::create(&file_name).expect("Failed to create CSV file"); + let mut writer = BufWriter::new(file); + + for (input_wire, ptxt_type) in input_map.iter() { + match ptxt_type { + PtxtType::Bool(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::U8(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::U16(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::U32(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::U64(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::U128(value) => { + writeln!(writer, "{}, {}", input_wire, value).expect("Failed to write record"); + } + PtxtType::None => unreachable!(), + } + } + println!("Decrypted outputs written to {}", file_name); + } +} + #[test] fn test_parser() { let (gates, wire_map, inputs, _, _, _, _) = read_verilog_file(