Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scrypt: Adds parallel feature and max_memory argument #178

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion scrypt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ base64ct = { version = "1", default-features = false, features = ["alloc"], opti
hmac = "0.11"
password-hash = { version = "0.2", default-features = false, features = ["rand_core"], optional = true }
pbkdf2 = { version = "0.8", default-features = false, path = "../pbkdf2" }
rayon = { version = "1", optional = true }
salsa20 = { version = "0.8", default-features = false, features = ["expose-core"] }
sha2 = { version = "0.9", default-features = false }

Expand All @@ -24,9 +25,10 @@ password-hash = { version = "0.2", features = ["rand_core"] }
rand_core = { version = "0.6", features = ["std"] }

[features]
default = ["simple", "std"]
default = ["simple", "std", "parallel"]
simple = ["password-hash", "base64ct"]
std = ["password-hash/std"]
parallel = ["rayon"]

[package.metadata.docs.rs]
all-features = true
Expand Down
26 changes: 26 additions & 0 deletions scrypt/benches/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,29 @@ pub fn scrypt_15_8_1(bh: &mut Bencher) {
test::black_box(&buf);
});
}

#[bench]
pub fn scrypt_parallel_15_8_4(bh: &mut Bencher) {
let password = b"my secure password";
let salt = b"salty salt";
let mut buf = [0u8; 32];
let params = scrypt::Params::new(15, 8, 4).unwrap();
bh.iter(|| {
scrypt::scrypt_parallel(password, salt, &params, 4 * 1024 * 1024 * 1024, 4, &mut buf)
.unwrap();
test::black_box(&buf);
});
}

#[bench]
pub fn scrypt_15_8_4(bh: &mut Bencher) {
let password = b"my secure password";
let salt = b"salty salt";
let mut buf = [0u8; 32];
let params = scrypt::Params::new(15, 8, 4).unwrap();
bh.iter(|| {
scrypt::scrypt_parallel(password, salt, &params, 4 * 1024 * 1024 * 1024, 1, &mut buf)
.unwrap();
test::black_box(&buf);
});
}
131 changes: 126 additions & 5 deletions scrypt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,36 +75,157 @@ pub use crate::simple::{Scrypt, ALG_ID};
/// **WARNING: Make sure to compare this value in constant time!**
///
/// # Return
/// `Ok(())` if calculation is succesfull and `Err(InvalidOutputLen)` if
/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if
/// `output` does not satisfy the following condition:
/// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`.
///
/// This function only uses a single thread (the current thread) for computation.
pub fn scrypt(
password: &[u8],
salt: &[u8],
params: &Params,
output: &mut [u8],
) -> Result<(), errors::InvalidOutputLen> {
scrypt_log_f(password, salt, params, 0, 1, output)
}

/// The scrypt key derivation function that may use multiple threads.
///
/// # Arguments
/// - `password` - The password to process as a byte vector
/// - `salt` - The salt value to use as a byte vector
/// - `params` - The ScryptParams to use
/// - `max_memory` - The maximum amount of memory to use, in bytes. May use slightly more (on the order of hundreds of bytes).
/// - `num_threads` - The maximum number of threads to use.
/// - `output` - The resulting derived key is returned in this byte vector.
/// **WARNING: Make sure to compare this value in constant time!**
///
/// # Return
/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if
/// `output` does not satisfy the following condition:
/// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`.
///
/// The parallel feature must be enabled for this function to use multiple threads.
/// Note that scrypt normally needs 2**log_n * 128 * r * min(num_threads, p) bytes for computation.
/// If max_memory is less than this, this implementation will automatically reduce memory usage.
/// Though this comes at the cost of increased computation.
/// (Note: It's always better to make this trade if it means using more CPU cores)
pub fn scrypt_parallel(
password: &[u8],
salt: &[u8],
params: &Params,
max_memory: usize,
num_threads: usize,
output: &mut [u8],
) -> Result<(), errors::InvalidOutputLen> {
// The checks in the ScryptParams constructor guarantee
// that the following is safe:
let n: usize = 1 << params.log_n;
let r128 = (params.r as usize) * 128;

// No point in using more than p threads.
let num_threads = num_threads.min(params.p as usize);

// The optimal log_f is always the one that allows the most cores to run.
// The increase in computation caused by increased log_f is always offset
// by the increased core usage. Thus log_f can be calculated based on
// num_threads and max_mem (assuming num_threads is less than or equal to
// the number of cores).
// So first we calculate how many blocks each thread can allocate.
let mem_per_thread = max_memory / num_threads;
let blocks_per_thread = mem_per_thread / r128;

if blocks_per_thread == 0 {
// TODO: Return error
panic!("Not enough memory");
}

// Now log_f is calculated by determining how far right we need to shift n
// to be less than or equal to blocks_per_thread.
let possible_log_f = blocks_per_thread
.leading_zeros()
.saturating_sub(n.leading_zeros());

// Rounding up.
let log_f = if (n >> possible_log_f) > blocks_per_thread {
// The checked_add should never fail.
possible_log_f.checked_add(1).expect("overflow")
} else {
possible_log_f
};

scrypt_log_f(password, salt, params, log_f, num_threads, output)
}

/// The scrypt key derivation function that accepts the raw log_f parameter.
///
/// # Arguments
/// - `password` - The password to process as a byte vector
/// - `salt` - The salt value to use as a byte vector
/// - `params` - The ScryptParams to use
/// - `log_f` - A factor that reduces memory usage at the cost of increased computation; must be less than or equal to params.log_n
/// - `num_threads` - The maximum number of threads to use.
/// - `output` - The resulting derived key is returned in this byte vector.
/// **WARNING: Make sure to compare this value in constant time!**
///
/// # Return
/// `Ok(())` if calculation is successful and `Err(InvalidOutputLen)` if
/// `output` does not satisfy the following condition:
/// `output.len() > 0 && output.len() <= (2^32 - 1) * 32`.
#[doc(hidden)]
pub fn scrypt_log_f(
password: &[u8],
salt: &[u8],
params: &Params,
log_f: u32,
num_threads: usize,
output: &mut [u8],
) -> Result<(), errors::InvalidOutputLen> {
// This check required by Scrypt:
// check output.len() > 0 && output.len() <= (2^32 - 1) * 32
if output.is_empty() || output.len() / 32 > 0xffff_ffff {
return Err(errors::InvalidOutputLen);
}

// log_f must be less than or equal to log_n, or else V will be 0 bytes.
assert!(log_f <= (params.log_n as u32));

// The checks in the ScryptParams constructor guarantee
// that the following is safe:
let n = 1 << params.log_n;
let r128 = (params.r as usize) * 128;
let pr128 = (params.p as usize) * r128;
let nr128 = n * r128;

// B is a set of `p` blocks of data, each `r128` in length.
let mut b = vec![0u8; pr128];
pbkdf2::<Hmac<Sha256>>(&password, salt, 1, &mut b);

let mut v = vec![0u8; nr128];
let mut t = vec![0u8; r128];
#[cfg(feature = "parallel")]
{
use rayon::prelude::*;

// Each chunk of B can be operated on in parallel. Rayon is used to
// distribute that work across the available threads.
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.expect("Unable to build rayon::ThreadPool");
pool.install(|| {
b.par_chunks_exact_mut(r128).for_each_init(
|| (vec![0u8; nr128 >> log_f], vec![0u8; r128 * 2]),
|(v, t), chunk| romix::scrypt_ro_mix(chunk, v, t, n, log_f),
);
});
}
#[cfg(not(feature = "parallel"))]
{
let mut v = vec![0u8; nr128 >> log_f];
let mut t = vec![0u8; r128 * 2];

for chunk in &mut b.chunks_mut(r128) {
romix::scrypt_ro_mix(chunk, &mut v, &mut t, n);
for chunk in b.chunks_exact_mut(r128) {
romix::scrypt_ro_mix(chunk, &mut v, &mut t, n, log_f);
}
}

pbkdf2::<Hmac<Sha256>>(&password, &b, 1, output);
Expand Down
56 changes: 49 additions & 7 deletions scrypt/src/romix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ use core::convert::TryInto;
type Salsa20_8 = salsa20::Core<salsa20::R8>;

/// Execute the ROMix operation in-place.
/// b - the data to operate on
/// v - a temporary variable to store the vector V
/// t - a temporary variable to store the result of the xor
/// b - the data to operate on; len must be a multiple of 128
/// v - a temporary variable to store the vector V; len must be (n >> log_f) * b.len()
/// t - a temporary variable; len must be b.len() * 2
/// n - the scrypt parameter N
/// log_f - a factor that reduces memory usage at the cost of computation; must always be less than or equal to log_n
/// To get a sense of how log_f works, the following formula calculates the total number
/// of operations performed for a given n and log_f:
/// ops(n, log_f) = 2 * n + 0.5 * n * (2**log_f - 1)
#[allow(clippy::many_single_char_names)]
pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize) {
pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize, log_f: u32) {
fn integerify(x: &[u8], n: usize) -> usize {
// n is a power of 2, so n - 1 gives us a bitmask that we can use to perform a calculation
// mod n using a simple bitwise and.
Expand All @@ -22,16 +26,54 @@ pub(crate) fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize)
}

let len = b.len();
let (t1, t2) = t.split_at_mut(len);

for chunk in v.chunks_mut(len) {
chunk.copy_from_slice(b);
scrypt_block_mix(chunk, b);

// Store 1 out of every 2**log_f values, so at 0 store every value, at 1 store every other value, etc.
if log_f == 0 {
scrypt_block_mix(chunk, b);
} else {
for _ in 0..((1 << log_f) >> 1) {
scrypt_block_mix(b, t1);
scrypt_block_mix(t1, b);
}
}
}

let f_mask = (1 << log_f) - 1;

for _ in 0..n {
let j = integerify(b, n);
xor(b, &v[j * len..(j + 1) * len], t);
scrypt_block_mix(t, b);
// Shift by log_f to get the nearest available stored block, rounded down.
let chunk = &v[(j >> log_f) * len..((j >> log_f) + 1) * len];

// When log_f > 0 we need to hash the fetched block to re-compute the hash of our
// desired block.
let n_hashes = j & f_mask;

for i in 0..n_hashes {
if i == 0 {
scrypt_block_mix(chunk, t1);
} else if i & 1 == 1 {
scrypt_block_mix(t1, t2);
} else {
scrypt_block_mix(t2, t1);
}
}

// Finally we xor and mix like usual, but need to use the right temporary variables.
if n_hashes == 0 {
xor(b, chunk, t1);
scrypt_block_mix(t1, b);
} else if n_hashes & 1 == 0 {
xor(b, t2, t1);
scrypt_block_mix(t1, b);
} else {
xor(b, t1, t2);
scrypt_block_mix(t2, b);
}
}
}

Expand Down
59 changes: 58 additions & 1 deletion scrypt/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use scrypt::{scrypt, Params};
use scrypt::{scrypt, scrypt_log_f, scrypt_parallel, Params};

#[cfg(feature = "simple")]
use {
Expand Down Expand Up @@ -81,6 +81,63 @@ fn test_scrypt() {
}
}

/// Tests that scrypt_parallel works correctly, even when max_memory is small.
#[test]
fn test_scrypt_parallel() {
let tests = tests();
for t in tests.iter() {
let mut result = vec![0u8; t.expected.len()];
let params = Params::new(t.log_n, t.r, t.p).unwrap();
scrypt_parallel(
t.password.as_bytes(),
t.salt.as_bytes(),
&params,
1024 * 1024,
4,
&mut result,
)
.unwrap();
assert!(result == t.expected);
}

for t in tests.iter() {
let mut result = vec![0u8; t.expected.len()];
let params = Params::new(t.log_n, t.r, t.p).unwrap();
scrypt_parallel(
t.password.as_bytes(),
t.salt.as_bytes(),
&params,
4 * 1024 * 1024 * 1024,
4,
&mut result,
)
.unwrap();
assert!(result == t.expected);
}
}

/// Tests various log_f values to ensure implementation is correct.
#[test]
fn test_scrypt_log_f() {
let tests = tests();
for log_f in 0..2 {
for t in tests.iter() {
let mut result = vec![0u8; t.expected.len()];
let params = Params::new(t.log_n, t.r, t.p).unwrap();
scrypt_log_f(
t.password.as_bytes(),
t.salt.as_bytes(),
&params,
log_f,
1,
&mut result,
)
.unwrap();
assert!(result == t.expected);
}
}
}

/// Test vector from passlib:
/// <https://passlib.readthedocs.io/en/stable/lib/passlib.hash.scrypt.html>
#[cfg(feature = "simple")]
Expand Down