Skip to content

Commit

Permalink
Merge remote-tracking branch 'ipa/main' into parallel-decryption-2
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Dec 20, 2024
2 parents 6a65171 + 291efab commit 5511e63
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 42 deletions.
22 changes: 16 additions & 6 deletions ipa-core/src/cli/crypto/hybrid_decrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,14 @@ mod tests {

let output_dir = tempdir().unwrap();
let network_file = hybrid_sample_data::test_keys().network_config();
HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir.path(),
network_file.path(),
false,
)
.encrypt()
.unwrap();

let decrypt_output = output_dir.path().join("output");
let enc1 = output_dir.path().join("DOES_NOT_EXIST.enc");
Expand Down Expand Up @@ -258,9 +263,14 @@ mod tests {

let network_file = hybrid_sample_data::test_keys().network_config();
let output_dir = tempdir().unwrap();
HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir.path(),
network_file.path(),
false,
)
.encrypt()
.unwrap();

let decrypt_output = output_dir.path().join("output");
let enc1 = output_dir.path().join("helper1.enc");
Expand Down
148 changes: 120 additions & 28 deletions ipa-core/src/cli/crypto/hybrid_encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,38 @@ pub struct HybridEncryptArgs {
/// Path to helper network configuration file
#[arg(long)]
network: PathBuf,
/// a flag to produce length delimited binary instead of newline delimited hex
#[arg(long)]
length_delimited: bool,
}

/// On-disk layout used when writing encrypted reports to the helper files.
///
/// Selected from the `length_delimited` command-line flag and passed to both
/// the encryption workers and the file writers.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum FileFormat {
    /// Reports are written as raw bytes, exactly as produced by the encryptor.
    // NOTE(review): the length prefix is presumably emitted by
    // `delimited_encrypt_to`, not by the file writer — confirm.
    LengthDelimitedBinary,
    /// Each report is hex-encoded and written on its own line.
    NewlineDelimitedHex,
}

impl HybridEncryptArgs {
#[must_use]
pub fn new(input_file: &Path, output_dir: &Path, network: &Path) -> Self {
pub fn new(
input_file: &Path,
output_dir: &Path,
network: &Path,
length_delimited: bool,
) -> Self {
Self {
input_file: input_file.to_path_buf(),
output_dir: output_dir.to_path_buf(),
network: network.to_path_buf(),
length_delimited,
}
}

/// Translates the `length_delimited` flag into the [`FileFormat`] to use
/// for the output files.
///
/// Returns [`FileFormat::LengthDelimitedBinary`] when the flag is set and
/// [`FileFormat::NewlineDelimitedHex`] otherwise.
fn file_format(&self) -> FileFormat {
    // Flag not set: fall back to the newline-delimited hex format.
    if !self.length_delimited {
        return FileFormat::NewlineDelimitedHex;
    }
    FileFormat::LengthDelimitedBinary
}

Expand Down Expand Up @@ -89,7 +112,8 @@ impl HybridEncryptArgs {
panic!("could not load network file")
};

let mut worker_pool = ReportWriter::new(key_registries, &self.output_dir);
let mut worker_pool =
ReportWriter::new(key_registries, &self.output_dir, self.file_format());
for (report_id, record) in input.iter::<TestHybridRecord>().enumerate() {
worker_pool.submit(report_id, record.share())?;
}
Expand Down Expand Up @@ -118,6 +142,7 @@ impl EncryptorPool {
thread_count: usize,
file_writer: [SyncSender<EncryptorOutput>; 3],
key_registries: [KeyRegistry<PublicKeyOnly>; 3],
file_format: FileFormat,
) -> Self {
Self {
pool: (0..thread_count)
Expand All @@ -132,11 +157,23 @@ impl EncryptorPool {
.spawn(move || {
for (i, helper_id, report) in rx {
let key_registry = &key_registries[helper_id];
let output = report.encrypt(
DEFAULT_KEY_ID,
key_registry,
&mut thread_rng(),
)?;
let mut output =
Vec::with_capacity(usize::from(report.encrypted_len() + 2));
match file_format {
FileFormat::NewlineDelimitedHex => report.encrypt_to(
DEFAULT_KEY_ID,
key_registry,
&mut thread_rng(),
&mut output,
)?,
FileFormat::LengthDelimitedBinary => report
.delimited_encrypt_to(
DEFAULT_KEY_ID,
key_registry,
&mut thread_rng(),
&mut output,
)?,
}
file_writer[helper_id].send((i, output))?;
}

Expand Down Expand Up @@ -178,7 +215,11 @@ struct ReportWriter {
}

impl ReportWriter {
pub fn new(key_registries: [KeyRegistry<PublicKeyOnly>; 3], output_dir: &Path) -> Self {
pub fn new(
key_registries: [KeyRegistry<PublicKeyOnly>; 3],
output_dir: &Path,
file_format: FileFormat,
) -> Self {
// create 3 worker threads to write data into 3 files
let workers = array::from_fn(|i| {
let output_filename = format!("helper{}.enc", i + 1);
Expand All @@ -188,12 +229,13 @@ impl ReportWriter {
.open(output_dir.join(&output_filename))
.unwrap_or_else(|e| panic!("unable write to {:?}. {}", &output_filename, e));

FileWriteWorker::new(file)
FileWriteWorker::new(file, file_format)
});
let encryptor_pool = EncryptorPool::with_worker_threads(
num_cpus::get(),
workers.each_ref().map(|x| x.sender.clone()),
key_registries,
file_format,
);

Self {
Expand Down Expand Up @@ -239,17 +281,26 @@ struct FileWriteWorker {
}

impl FileWriteWorker {
pub fn new(file: File) -> Self {
pub fn new(file: File, file_format: FileFormat) -> Self {
/// Writes a single encrypted `report` to `writer` using the layout
/// selected by `file_format`.
///
/// Delegates to the matching `FileWriteWorker::write_report_*` helper and
/// returns whatever that helper returns.
fn write_report<W: Write>(
    writer: &mut W,
    report: &[u8],
    file_format: FileFormat,
) -> Result<(), BoxError> {
    // Pick the concrete writer first, then invoke it once.
    let emit: fn(&mut W, &[u8]) -> Result<(), BoxError> = match file_format {
        FileFormat::NewlineDelimitedHex => FileWriteWorker::write_report_newline_delimited_hex,
        FileFormat::LengthDelimitedBinary => {
            FileWriteWorker::write_report_length_delimited_binary
        }
    };
    emit(writer, report)
}

let (tx, rx) = std::sync::mpsc::sync_channel(65535);
Self {
sender: tx,
handle: thread::spawn(move || {
fn write_report<W: Write>(writer: &mut W, report: &[u8]) -> Result<(), BoxError> {
let hex_output = hex::encode(report);
writeln!(writer, "{hex_output}")?;
Ok(())
}

// write low watermark. All reports below this line have been written
let mut lw = 0;
let mut pending_reports = BTreeMap::new();
Expand All @@ -271,7 +322,7 @@ impl FileWriteWorker {
"Internal error: received a duplicate report {report_id}"
);
while let Some(report) = pending_reports.remove(&lw) {
write_report(&mut writer, &report)?;
write_report(&mut writer, &report, file_format)?;
lw += 1;
if lw % 1_000_000 == 0 {
tracing::info!("Encrypted {}M reports", lw / 1_000_000);
Expand All @@ -282,6 +333,23 @@ impl FileWriteWorker {
}),
}
}

/// Writes `report` to `writer` as a hex string followed by a newline.
///
/// # Errors
/// Returns an error if writing to `writer` fails.
fn write_report_newline_delimited_hex<W: Write>(
    writer: &mut W,
    report: &[u8],
) -> Result<(), BoxError> {
    // One line per report: hex-encoded payload terminated by `\n`.
    writeln!(writer, "{}", hex::encode(report))?;
    Ok(())
}

/// Writes the bytes of `report` to `writer` unchanged.
///
/// No delimiter is added here.
// NOTE(review): presumably the length prefix is already part of `report`
// (produced upstream by `delimited_encrypt_to`) — confirm.
///
/// # Errors
/// Returns an error if writing to `writer` fails.
fn write_report_length_delimited_binary<W: Write>(
    writer: &mut W,
    report: &[u8],
) -> Result<(), BoxError> {
    Ok(writer.write_all(report)?)
}
}

#[cfg(all(test, unit_test))]
Expand Down Expand Up @@ -334,12 +402,26 @@ mod tests {
}
input_file.flush().unwrap();

let output_dir = tempdir().unwrap();
let output_dir_1 = tempdir().unwrap();
let output_dir_2 = tempdir().unwrap();
let network_file = sample_data::test_keys().network_config();

HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir_1.path(),
network_file.path(),
false,
)
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir_2.path(),
network_file.path(),
true,
)
.encrypt()
.unwrap();
}

#[test]
Expand All @@ -350,7 +432,7 @@ mod tests {
let output_dir = tempdir().unwrap();
let network_dir = tempdir().unwrap();
let network_file = network_dir.path().join("does_not_exist");
HybridEncryptArgs::new(input_file.path(), output_dir.path(), &network_file)
HybridEncryptArgs::new(input_file.path(), output_dir.path(), &network_file, true)
.encrypt()
.unwrap();
}
Expand All @@ -368,9 +450,14 @@ this is not toml!
let mut network_file = NamedTempFile::new().unwrap();
writeln!(network_file.as_file_mut(), "{network_data}").unwrap();

HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir.path(),
network_file.path(),
true,
)
.encrypt()
.unwrap();
}

#[test]
Expand All @@ -392,8 +479,13 @@ public_key = "cfdbaaff16b30aa8a4ab07eaad2cdd80458208a1317aefbb807e46dce596617e"
let mut network_file = NamedTempFile::new().unwrap();
writeln!(network_file.as_file_mut(), "{network_data}").unwrap();

HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir.path(),
network_file.path(),
true,
)
.encrypt()
.unwrap();
}
}
11 changes: 8 additions & 3 deletions ipa-core/src/cli/crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,14 @@ mod tests {
let input = hybrid_sample_data::test_hybrid_data().take(10);
let input_file = hybrid_sample_data::write_csv(input).unwrap();
let network_file = hybrid_sample_data::test_keys().network_config();
HybridEncryptArgs::new(input_file.path(), output_dir.path(), network_file.path())
.encrypt()
.unwrap();
HybridEncryptArgs::new(
input_file.path(),
output_dir.path(),
network_file.path(),
false,
)
.encrypt()
.unwrap();

let decrypt_output = output_dir.path().join("output");
let enc1 = output_dir.path().join("helper1.enc");
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/net/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub enum Error {
status: hyper::StatusCode,
reason: String,
},
#[error("Failed to connect to {dest}: {inner}")]
#[error("Failed to connect to {dest}: {inner:?}")]
ConnectError {
dest: String,
#[source]
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/net/query_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub async fn stream_query_input_from_url(uri: &Uri) -> Result<BodyStream, Error>
HttpsConnectorBuilder::default()
.with_native_roots()
.expect("System truststore is required")
.https_only()
.https_or_http()
.enable_all_versions()
.build(),
);
Expand Down
5 changes: 2 additions & 3 deletions ipa-core/src/net/server/handlers/query/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread")]
async fn input_from_url() {
const QUERY_ID: QueryId = QueryId;
const DATA: &str = "input records";
const DATA: &str = "<input records>";

let server = tiny_http::Server::http("localhost:0").unwrap();
let addr = server.server_addr();
Expand All @@ -124,9 +124,8 @@ mod tests {
.await;

let url = format!(
"http://localhost:{}{}/{QUERY_ID}/input",
"http://localhost:{}/input-data",
addr.to_ip().unwrap().port(),
http_serde::query::BASE_AXUM_PATH,
);
let req = http_serde::query::input::Request::new(QueryInput::FromUrl {
query_id: QUERY_ID,
Expand Down

0 comments on commit 5511e63

Please sign in to comment.