From f8735676f0591ebfb2a59a2a076c2c3fa80cd5be Mon Sep 17 00:00:00 2001 From: fpgaminer Date: Tue, 25 May 2021 13:23:45 -0700 Subject: [PATCH] scrypt: add parallel feature, memory reduction feature, and max_memory argument --- Cargo.lock | 1 + scrypt/Cargo.toml | 4 +- scrypt/benches/lib.rs | 26 +++++++++ scrypt/src/lib.rs | 131 ++++++++++++++++++++++++++++++++++++++++-- scrypt/src/romix.rs | 56 +++++++++++++++--- scrypt/tests/mod.rs | 59 ++++++++++++++++++- 6 files changed, 263 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c620532..31ed1026 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,6 +386,7 @@ dependencies = [ "password-hash", "pbkdf2", "rand_core", + "rayon", "salsa20", "sha2", ] diff --git a/scrypt/Cargo.toml b/scrypt/Cargo.toml index 08b0d512..82d4d1bf 100644 --- a/scrypt/Cargo.toml +++ b/scrypt/Cargo.toml @@ -16,6 +16,7 @@ base64ct = { version = "1", default-features = false, features = ["alloc"], opti hmac = "0.11" password-hash = { version = "0.2", default-features = false, features = ["rand_core"], optional = true } pbkdf2 = { version = "0.8", default-features = false, path = "../pbkdf2" } +rayon = { version = "1", optional = true } salsa20 = { version = "0.8", default-features = false, features = ["expose-core"] } sha2 = { version = "0.9", default-features = false } @@ -24,9 +25,10 @@ password-hash = { version = "0.2", features = ["rand_core"] } rand_core = { version = "0.6", features = ["std"] } [features] -default = ["simple", "std"] +default = ["simple", "std", "parallel"] simple = ["password-hash", "base64ct"] std = ["password-hash/std"] +parallel = ["rayon"] [package.metadata.docs.rs] all-features = true diff --git a/scrypt/benches/lib.rs b/scrypt/benches/lib.rs index 838a63a3..26bc72aa 100644 --- a/scrypt/benches/lib.rs +++ b/scrypt/benches/lib.rs @@ -16,3 +16,29 @@ pub fn scrypt_15_8_1(bh: &mut Bencher) { test::black_box(&buf); }); } + +#[bench] +pub fn scrypt_parallel_15_8_4(bh: &mut Bencher) { + let password = 
b"my secure password"; + let salt = b"salty salt"; + let mut buf = [0u8; 32]; + let params = scrypt::Params::new(15, 8, 4).unwrap(); + bh.iter(|| { + scrypt::scrypt_parallel(password, salt, &params, 4 * 1024 * 1024 * 1024, 4, &mut buf) + .unwrap(); + test::black_box(&buf); + }); +} + +#[bench] +pub fn scrypt_15_8_4(bh: &mut Bencher) { + let password = b"my secure password"; + let salt = b"salty salt"; + let mut buf = [0u8; 32]; + let params = scrypt::Params::new(15, 8, 4).unwrap(); + bh.iter(|| { + scrypt::scrypt_parallel(password, salt, &params, 4 * 1024 * 1024 * 1024, 1, &mut buf) + .unwrap(); + test::black_box(&buf); + }); +} diff --git a/scrypt/src/lib.rs b/scrypt/src/lib.rs index a66aea8b..f571ae4f 100644 --- a/scrypt/src/lib.rs +++ b/scrypt/src/lib.rs @@ -75,14 +75,111 @@ pub use crate::simple::{Scrypt, ALG_ID}; /// **WARNING: Make sure to compare this value in constant time!** /// /// # Return -/// `Ok(())` if calculation is succesfull and `Err(InvalidOutputLen)` if +/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if /// `output` does not satisfy the following condition: /// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`. +/// +/// This function only uses a single thread (the current thread) for computation. pub fn scrypt( password: &[u8], salt: &[u8], params: &Params, output: &mut [u8], +) -> Result<(), errors::InvalidOutputLen> { + scrypt_log_f(password, salt, params, 0, 1, output) +} + +/// The scrypt key derivation function that may use multiple threads. +/// +/// # Arguments +/// - `password` - The password to process as a byte vector +/// - `salt` - The salt value to use as a byte vector +/// - `params` - The ScryptParams to use +/// - `max_memory` - The maximum amount of memory to use, in bytes. May use slightly more (on the order of hundreds of bytes). +/// - `num_threads` - The maximum number of threads to use. +/// - `output` - The resulting derived key is returned in this byte vector. 
+/// **WARNING: Make sure to compare this value in constant time!** +/// +/// # Return +/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if +/// `output` does not satisfy the following condition: +/// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`. +/// +/// The parallel feature must be enabled for this function to use multiple threads. +/// Note that scrypt normally needs 2**log_n * 128 * r * min(num_threads, p) bytes for computation. +/// If max_memory is less than this, this implementation will automatically reduce memory usage. +/// Though this comes at the cost of increased computation. +/// (Note: It's always better to make this trade if it means using more CPU cores) +pub fn scrypt_parallel( + password: &[u8], + salt: &[u8], + params: &Params, + max_memory: usize, + num_threads: usize, + output: &mut [u8], +) -> Result<(), errors::InvalidOutputLen> { + // The checks in the ScryptParams constructor guarantee + // that the following is safe: + let n: usize = 1 << params.log_n; + let r128 = (params.r as usize) * 128; + + // No point in using more than p threads. + let num_threads = num_threads.min(params.p as usize); + + // The optimal log_f is always the one that allows the most cores to run. + // The increase in computation caused by increased log_f is always offset + // by the increased core usage. Thus log_f can be calculated based on + // num_threads and max_mem (assuming num_threads is less than or equal to + // the number of cores). + // So first we calculate how many blocks each thread can allocate. + let mem_per_thread = max_memory / num_threads; + let blocks_per_thread = mem_per_thread / r128; + + if blocks_per_thread == 0 { + // TODO: Return error + panic!("Not enough memory"); + } + + // Now log_f is calculated by determining how far right we need to shift n + // to be less than or equal to blocks_per_thread. 
+ let possible_log_f = blocks_per_thread + .leading_zeros() + .saturating_sub(n.leading_zeros()); + + // Rounding up. + let log_f = if (n >> possible_log_f) > blocks_per_thread { + // The checked_add should never fail. + possible_log_f.checked_add(1).expect("overflow") + } else { + possible_log_f + }; + + scrypt_log_f(password, salt, params, log_f, num_threads, output) +} + +/// The scrypt key derivation function that accepts the raw log_f parameter. +/// +/// # Arguments +/// - `password` - The password to process as a byte vector +/// - `salt` - The salt value to use as a byte vector +/// - `params` - The ScryptParams to use +/// - `log_f` - A factor that reduces memory usage at the cost of increased computation; must be less than or equal to params.log_n +/// - `num_threads` - The maximum number of threads to use. +/// - `output` - The resulting derived key is returned in this byte vector. +/// **WARNING: Make sure to compare this value in constant time!** +/// +/// # Return +/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if +/// `output` does not satisfy the following condition: +/// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`. +#[doc(hidden)] +pub fn scrypt_log_f( + password: &[u8], + salt: &[u8], + params: &Params, + log_f: u32, + num_threads: usize, + output: &mut [u8], ) -> Result<(), errors::InvalidOutputLen> { // This check required by Scrypt: // check output.len() > 0 && output.len() <= (2^32 - 1) * 32 @@ -90,6 +187,9 @@ pub fn scrypt( return Err(errors::InvalidOutputLen); } + // log_f must be less than or equal to log_n, or else V will be 0 bytes. + assert!(log_f <= (params.log_n as u32)); + // The checks in the ScryptParams constructor guarantee // that the following is safe: let n = 1 << params.log_n; @@ -97,14 +197,35 @@ pub fn scrypt( let pr128 = (params.p as usize) * r128; let nr128 = n * r128; + // B is a set of `p` blocks of data, each `r128` in length. 
let mut b = vec![0u8; pr128]; pbkdf2::<Hmac<Sha256>>(&password, salt, 1, &mut b); - let mut v = vec![0u8; nr128]; - let mut t = vec![0u8; r128]; + #[cfg(feature = "parallel")] + { + use rayon::prelude::*; + + // Each chunk of B can be operated on in parallel. Rayon is used to + // distribute that work across the available threads. + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build() + .expect("Unable to build rayon::ThreadPool"); + pool.install(|| { + b.par_chunks_exact_mut(r128).for_each_init( + || (vec![0u8; nr128 >> log_f], vec![0u8; r128 * 2]), + |(v, t), chunk| romix::scrypt_ro_mix(chunk, v, t, n, log_f), + ); + }); + } + #[cfg(not(feature = "parallel"))] + { + let mut v = vec![0u8; nr128 >> log_f]; + let mut t = vec![0u8; r128 * 2]; - for chunk in &mut b.chunks_mut(r128) { - romix::scrypt_ro_mix(chunk, &mut v, &mut t, n); + for chunk in b.chunks_exact_mut(r128) { + romix::scrypt_ro_mix(chunk, &mut v, &mut t, n, log_f); + } } pbkdf2::<Hmac<Sha256>>(&password, &b, 1, output); diff --git a/scrypt/src/romix.rs b/scrypt/src/romix.rs index 91b44e54..99c6dc25 100644 --- a/scrypt/src/romix.rs +++ b/scrypt/src/romix.rs @@ -4,12 +4,16 @@ use core::convert::TryInto; type Salsa20_8 = salsa20::Core<salsa20::R8>; /// Execute the ROMix operation in-place. 
-/// b - the data to operate on -/// v - a temporary variable to store the vector V -/// t - a temporary variable to store the result of the xor +/// b - the data to operate on; len must be a multiple of 128 +/// v - a temporary variable to store the vector V; len must be (n >> log_f) * b.len() +/// t - a temporary variable; len must be b.len() * 2 /// n - the scrypt parameter N +/// log_f - a factor that reduces memory usage at the cost of computation; must always be less than or equal to log_n +/// To get a sense of how log_f works, the following formula calculates the total number +/// of operations performed for a given n and log_f: +/// ops(n, log_f) = 2 * n + 0.5 * n * (2**log_f - 1) #[allow(clippy::many_single_char_names)] -pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize) { +pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize, log_f: u32) { fn integerify(x: &[u8], n: usize) -> usize { // n is a power of 2, so n - 1 gives us a bitmask that we can use to perform a calculation // mod n using a simple bitwise and. @@ -22,16 +26,54 @@ pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize) } let len = b.len(); + let (t1, t2) = t.split_at_mut(len); for chunk in v.chunks_mut(len) { chunk.copy_from_slice(b); - scrypt_block_mix(chunk, b); + + // Store 1 out of every 2**log_f values, so at 0 store every value, at 1 store every other value, etc. + if log_f == 0 { + scrypt_block_mix(chunk, b); + } else { + for _ in 0..((1 << log_f) >> 1) { + scrypt_block_mix(b, t1); + scrypt_block_mix(t1, b); + } + } } + let f_mask = (1 << log_f) - 1; + for _ in 0..n { let j = integerify(b, n); - xor(b, &v[j * len..(j + 1) * len], t); - scrypt_block_mix(t, b); + // Shift by log_f to get the nearest available stored block, rounded down. + let chunk = &v[(j >> log_f) * len..((j >> log_f) + 1) * len]; + + // When log_f > 0 we need to hash the fetched block to re-compute the hash of our + // desired block. 
+ let n_hashes = j & f_mask; + + for i in 0..n_hashes { + if i == 0 { + scrypt_block_mix(chunk, t1); + } else if i & 1 == 1 { + scrypt_block_mix(t1, t2); + } else { + scrypt_block_mix(t2, t1); + } + } + + // Finally we xor and mix like usual, but need to use the right temporary variables. + if n_hashes == 0 { + xor(b, chunk, t1); + scrypt_block_mix(t1, b); + } else if n_hashes & 1 == 0 { + xor(b, t2, t1); + scrypt_block_mix(t1, b); + } else { + xor(b, t1, t2); + scrypt_block_mix(t2, b); + } } } diff --git a/scrypt/tests/mod.rs b/scrypt/tests/mod.rs index 59dd0893..8ad4842c 100644 --- a/scrypt/tests/mod.rs +++ b/scrypt/tests/mod.rs @@ -1,4 +1,4 @@ -use scrypt::{scrypt, Params}; +use scrypt::{scrypt, scrypt_log_f, scrypt_parallel, Params}; #[cfg(feature = "simple")] use { @@ -81,6 +81,63 @@ fn test_scrypt() { } } +/// Tests that scrypt_parallel works correctly, even when max_memory is small. +#[test] +fn test_scrypt_parallel() { + let tests = tests(); + for t in tests.iter() { + let mut result = vec![0u8; t.expected.len()]; + let params = Params::new(t.log_n, t.r, t.p).unwrap(); + scrypt_parallel( + t.password.as_bytes(), + t.salt.as_bytes(), + &params, + 1024 * 1024, + 4, + &mut result, + ) + .unwrap(); + assert!(result == t.expected); + } + + for t in tests.iter() { + let mut result = vec![0u8; t.expected.len()]; + let params = Params::new(t.log_n, t.r, t.p).unwrap(); + scrypt_parallel( + t.password.as_bytes(), + t.salt.as_bytes(), + &params, + 4 * 1024 * 1024 * 1024, + 4, + &mut result, + ) + .unwrap(); + assert!(result == t.expected); + } +} + +/// Tests various log_f values to ensure implementation is correct. 
+#[test] +fn test_scrypt_log_f() { + let tests = tests(); + for log_f in 0..2 { + for t in tests.iter() { + let mut result = vec![0u8; t.expected.len()]; + let params = Params::new(t.log_n, t.r, t.p).unwrap(); + scrypt_log_f( + t.password.as_bytes(), + t.salt.as_bytes(), + &params, + log_f, + 1, + &mut result, + ) + .unwrap(); + assert!(result == t.expected); + } + } +} + /// Test vector from passlib: /// #[cfg(feature = "simple")]