diff --git a/age/CHANGELOG.md b/age/CHANGELOG.md
index 3f66a0f3..f77ee549 100644
--- a/age/CHANGELOG.md
+++ b/age/CHANGELOG.md
@@ -9,6 +9,10 @@ and this project adheres to Rust's notion of
 to 1.0.0 are beta releases.
 
 ## [Unreleased]
+### Added
+- `age::Decryptor::new_buffered`, which is more efficient for types implementing
+  `std::io::BufRead` (which includes `&[u8]` slices).
+- `impl std::io::BufRead for age::armor::ArmoredReader`
 
 ## [0.9.1] - 2023-03-24
 ### Added
diff --git a/age/benches/parser.rs b/age/benches/parser.rs
index 6384351e..e67fb579 100644
--- a/age/benches/parser.rs
+++ b/age/benches/parser.rs
@@ -30,7 +30,7 @@ fn bench(c: &mut Criterion) {
         output.write_all(&[]).unwrap();
         output.finish().unwrap();
 
-        b.iter(|| Decryptor::new(&encrypted[..]))
+        b.iter(|| Decryptor::new_buffered(&encrypted[..]))
     });
 }
diff --git a/age/benches/throughput.rs b/age/benches/throughput.rs
index 2e7e4483..7a3c8e5f 100644
--- a/age/benches/throughput.rs
+++ b/age/benches/throughput.rs
@@ -70,7 +70,7 @@ fn bench(c: &mut Criterion_) {
         output.finish().unwrap();
 
         b.iter(|| {
-            let decryptor = match Decryptor::new(&ct_buf[..]).unwrap() {
+            let decryptor = match Decryptor::new_buffered(&ct_buf[..]).unwrap() {
                 Decryptor::Recipients(decryptor) => decryptor,
                 _ => panic!(),
            };
diff --git a/age/src/format.rs b/age/src/format.rs
index 7bd499f4..b4759138 100644
--- a/age/src/format.rs
+++ b/age/src/format.rs
@@ -1,7 +1,7 @@
 //! The age file format.
 
 use age_core::format::Stanza;
-use std::io::{self, Read, Write};
+use std::io::{self, BufRead, Read, Write};
 
 use crate::{
     error::DecryptError,
@@ -88,6 +88,34 @@ impl Header {
         }
     }
 
+    pub(crate) fn read_buffered<R: BufRead>(mut input: R) -> Result<Self, DecryptError> {
+        let mut data = vec![];
+        loop {
+            match read::header(&data) {
+                Ok((_, mut header)) => {
+                    if let Header::V1(h) = &mut header {
+                        h.encoded_bytes = Some(data);
+                    }
+                    break Ok(header);
+                }
+                Err(nom::Err::Incomplete(nom::Needed::Size(_))) => {
+                    // As we have a buffered reader, we can leverage the fact that the
+                    // currently-defined header formats are newline-separated, to more
+                    // efficiently read data for the parser to consume.
+                    if input.read_until(b'\n', &mut data)? == 0 {
+                        break Err(DecryptError::Io(io::Error::new(
+                            io::ErrorKind::UnexpectedEof,
+                            "Incomplete header",
+                        )));
+                    }
+                }
+                Err(_) => {
+                    break Err(DecryptError::InvalidHeader);
+                }
+            }
+        }
+    }
+
     #[cfg(feature = "async")]
     #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
     pub(crate) async fn read_async<R: AsyncRead + Unpin>(
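The new `Header::read_buffered` above leans on the header format being newline-separated: whenever the nom parser reports `Incomplete`, one whole line is appended to the buffer with `BufRead::read_until`, instead of pulling bytes one at a time through `io::Read`. A minimal, standalone sketch of that pattern follows; the `parse` closure and names here are illustrative stand-ins, not age's real parser.

```rust
use std::io::{self, BufRead};

/// Feeds an incremental parser one newline-terminated chunk at a time.
/// `parse` returns `None` while it still needs more data.
fn read_incremental<R: BufRead, T>(
    mut input: R,
    parse: impl Fn(&[u8]) -> Option<T>,
) -> io::Result<T> {
    let mut data = Vec::new();
    loop {
        if let Some(value) = parse(&data) {
            return Ok(value);
        }
        // One line-sized read per iteration instead of a byte at a time.
        if input.read_until(b'\n', &mut data)? == 0 {
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete input"));
        }
    }
}

fn main() -> io::Result<()> {
    // `&[u8]` implements `BufRead`, so in-memory data needs no extra wrapping.
    let src: &[u8] = b"first line\nsecond line\nEND\n";
    let parsed = read_incremental(src, |buf| {
        std::str::from_utf8(buf)
            .ok()
            .filter(|s| s.contains("END\n"))
            .map(String::from)
    })?;
    assert!(parsed.ends_with("END\n"));
    Ok(())
}
```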
diff --git a/age/src/primitives/armor.rs b/age/src/primitives/armor.rs
index 1ac5997c..240cca0e 100644
--- a/age/src/primitives/armor.rs
+++ b/age/src/primitives/armor.rs
@@ -726,6 +726,7 @@ impl<R> ArmoredReader<R> {
     ///
     /// Returns the number of bytes read into the buffer, or None if there was no cached
     /// data.
+    #[cfg(feature = "async")]
     fn read_cached_data(&mut self, buf: &mut [u8]) -> Option<usize> {
         if self.byte_start >= self.byte_end {
             None
@@ -820,8 +821,8 @@ impl<R> ArmoredReader<R> {
     }
 }
 
-impl<R: BufRead> Read for ArmoredReader<R> {
-    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
+impl<R: BufRead> BufRead for ArmoredReader<R> {
+    fn fill_buf(&mut self) -> io::Result<&[u8]> {
         loop {
             match self.is_armored {
                 None => {
@@ -829,60 +830,97 @@ impl<R: BufRead> Read for ArmoredReader<R> {
                     self.detect_armor()?
                 }
                 Some(false) => {
-                    // Return any leftover data from armor detection
-                    return if let Some(read) = self.read_cached_data(buf) {
-                        Ok(read)
-                    } else {
-                        self.inner.read(buf).map(|read| {
-                            self.data_read += read;
-                            self.count_reader_bytes(read)
+                    break if self.byte_start >= self.byte_end {
+                        self.inner.read(&mut self.byte_buf[..]).map(|read| {
+                            self.byte_start = 0;
+                            self.byte_end = read;
+                            self.count_reader_bytes(read);
+                            &self.byte_buf[..read]
                         })
-                    };
+                    } else {
+                        Ok(&self.byte_buf[self.byte_start..self.byte_end])
+                    }
+                }
+                Some(true) => {
+                    break if self.found_end {
+                        Ok(&[])
+                    } else if self.byte_start >= self.byte_end {
+                        if self.read_next_armor_line()? {
+                            Ok(&[])
+                        } else {
+                            Ok(&self.byte_buf[self.byte_start..self.byte_end])
+                        }
+                    } else {
+                        Ok(&self.byte_buf[self.byte_start..self.byte_end])
+                    }
                 }
-                Some(true) => break,
             }
         }
-        if self.found_end {
-            return Ok(0);
-        }
+    }
 
-        let buf_len = buf.len();
+    fn consume(&mut self, amt: usize) {
+        self.byte_start += amt;
+        self.data_read += amt;
+        assert!(self.byte_start <= self.byte_end);
+    }
+}
 
-        // Output any remaining bytes from the previous line
-        if let Some(read) = self.read_cached_data(buf) {
-            buf = &mut buf[read..];
-        }
+impl<R: BufRead> ArmoredReader<R> {
+    /// Fills `self.byte_buf` with the next line of armored data.
+    ///
+    /// Returns `true` if this was the last line.
+    fn read_next_armor_line(&mut self) -> io::Result<bool> {
+        assert_eq!(self.is_armored, Some(true));
 
-        while !buf.is_empty() {
-            // Read the next line
-            self.inner
-                .read_line(&mut self.line_buf)
-                .map(|read| self.count_reader_bytes(read))?;
-
-            // Parse the line into bytes
-            if self.parse_armor_line()? {
-                // This was the last line! Check for trailing garbage.
-                loop {
-                    let amt = match self.inner.fill_buf()? {
-                        &[] => break,
-                        buf => {
-                            if buf.iter().any(|b| !b.is_ascii_whitespace()) {
-                                return Err(io::Error::new(
-                                    io::ErrorKind::InvalidData,
-                                    ArmoredReadError::TrailingGarbage,
-                                ));
-                            }
-                            buf.len()
+        // Read the next line
+        self.inner
+            .read_line(&mut self.line_buf)
+            .map(|read| self.count_reader_bytes(read))?;
+
+        // Parse the line into bytes
+        if self.parse_armor_line()? {
+            // This was the last line! Check for trailing garbage.
+            loop {
+                let amt = match self.inner.fill_buf()? {
+                    &[] => break,
+                    buf => {
+                        if buf.iter().any(|b| !b.is_ascii_whitespace()) {
+                            return Err(io::Error::new(
+                                io::ErrorKind::InvalidData,
+                                ArmoredReadError::TrailingGarbage,
+                            ));
                         }
-                    };
-                    self.inner.consume(amt);
-                }
-                break;
+                        buf.len()
+                    }
+                };
+                self.inner.consume(amt);
             }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+}
+
+impl<R: BufRead> Read for ArmoredReader<R> {
+    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
+        let buf_len = buf.len();
+
+        while !buf.is_empty() {
+            match self.fill_buf()? {
+                [] => break,
+                next => {
+                    let read = cmp::min(next.len(), buf.len());
+
+                    if next.len() < buf.len() {
+                        buf[..read].copy_from_slice(next);
+                    } else {
+                        buf.copy_from_slice(&next[..read]);
+                    }
 
-            // Output as much as we can of this line
-            if let Some(read) = self.read_cached_data(buf) {
-                buf = &mut buf[read..];
+                    self.consume(read);
+                    buf = &mut buf[read..];
+                }
             }
         }
@@ -1061,49 +1099,9 @@ impl<R: BufRead + Seek> Seek for ArmoredReader<R> {
                     self.inner.read_exact(&mut self.byte_buf[..MIN_ARMOR_LEN])?;
                     self.detect_armor()?
                 }
-                Some(false) => {
-                    break if self.byte_start >= self.byte_end {
-                        // Map the data read onto the underlying stream.
-                        let start = self.start()?;
-                        let pos = match pos {
-                            SeekFrom::Start(offset) => SeekFrom::Start(start + offset),
-                            // Current and End positions don't need to be shifted.
-                            x => x,
-                        };
-                        self.inner.seek(pos)
-                    } else {
-                        // We are still inside the first line.
-                        match pos {
-                            SeekFrom::Start(offset) => self.byte_start = offset as usize,
-                            SeekFrom::Current(offset) => {
-                                let res = (self.byte_start as i64) + offset;
-                                if res >= 0 {
-                                    self.byte_start = res as usize;
-                                } else {
-                                    return Err(io::Error::new(
-                                        io::ErrorKind::InvalidData,
-                                        "cannot seek before the start",
-                                    ));
-                                }
-                            }
-                            SeekFrom::End(offset) => {
-                                let res = (self.line_buf.len() as i64) + offset;
-                                if res >= 0 {
-                                    self.byte_start = res as usize;
-                                } else {
-                                    return Err(io::Error::new(
-                                        io::ErrorKind::InvalidData,
-                                        "cannot seek before the start",
-                                    ));
-                                }
-                            }
-                        }
-                        Ok(self.byte_start as u64)
-                    };
-                }
-                Some(true) => {
+                Some(armored) => {
                     // Convert the offset into the target position within the data inside
-                    // the armor.
+                    // the (maybe) armor.
                     let start = self.start()?;
                     let target_pos = match pos {
                         SeekFrom::Start(offset) => offset,
@@ -1146,6 +1144,15 @@ impl<R: BufRead + Seek> Seek for ArmoredReader<R> {
                 }
             };
 
+            if !armored {
+                // We can seek directly on the inner reader.
+                self.inner.seek(SeekFrom::Start(start + target_pos))?;
+                self.byte_start = 0;
+                self.byte_end = 0;
+                self.data_read = target_pos as usize;
+                break Ok(self.data_read as u64);
+            }
+
             // Jump back to the start of the armor data, and then read and drop
             // until we reach the target position. This is very inefficient, but
             // as armored files can have arbitrary line endings within the file,
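The `BufRead` implementation added above exposes the armor reader's internal buffer through the standard `fill_buf`/`consume` contract: `fill_buf` returns the currently buffered bytes (an empty slice signals end of data) and `consume` records how many of them the caller used. The rewritten `Read::read` is built on exactly that loop. Here is the same contract exercised against an arbitrary `BufRead` source — plain std, nothing age-specific:

```rust
use std::io::{self, BufRead};

/// Drains any `BufRead` source using only `fill_buf` and `consume`.
fn drain_to_vec<R: BufRead>(mut reader: R) -> io::Result<Vec<u8>> {
    let mut out = Vec::new();
    loop {
        let chunk = reader.fill_buf()?;
        if chunk.is_empty() {
            break; // An empty buffer at this point means EOF.
        }
        let used = chunk.len();
        out.extend_from_slice(chunk);
        reader.consume(used);
    }
    Ok(out)
}

fn main() -> io::Result<()> {
    let data: &[u8] = b"example payload";
    assert_eq!(drain_to_vec(data)?, b"example payload".to_vec());
    Ok(())
}
```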
diff --git a/age/src/protocol.rs b/age/src/protocol.rs
index b5229df5..6c6df056 100644
--- a/age/src/protocol.rs
+++ b/age/src/protocol.rs
@@ -2,7 +2,7 @@
 use age_core::{format::grease_the_joint, secrecy::SecretString};
 use rand::{rngs::OsRng, RngCore};
-use std::io::{self, Read, Write};
+use std::io::{self, BufRead, Read, Write};
 
 use crate::{
     error::{DecryptError, EncryptError},
@@ -181,6 +181,13 @@ impl<R: Read> Decryptor<R> {
     /// Attempts to create a decryptor for an age file.
     ///
     /// Returns an error if the input does not contain a valid age file.
+    ///
+    /// # Performance
+    ///
+    /// This constructor will work with any type implementing [`io::Read`], and uses a
+    /// slower parser and internal buffering to ensure no overreading occurs. Consider
+    /// using [`Decryptor::new_buffered`] for types implementing `std::io::BufRead`, which
+    /// includes `&[u8]` slices.
     pub fn new(mut input: R) -> Result<Self, DecryptError> {
         let header = Header::read(&mut input)?;
 
@@ -194,6 +201,28 @@ impl<R: Read> Decryptor<R> {
     }
 }
 
+impl<R: BufRead> Decryptor<R> {
+    /// Attempts to create a decryptor for an age file.
+    ///
+    /// Returns an error if the input does not contain a valid age file.
+    ///
+    /// # Performance
+    ///
+    /// This constructor is more performant than [`Decryptor::new`] for types implementing
+    /// [`io::BufRead`], which includes `&[u8]` slices.
+    pub fn new_buffered(mut input: R) -> Result<Self, DecryptError> {
+        let header = Header::read_buffered(&mut input)?;
+
+        match header {
+            Header::V1(v1_header) => {
+                let nonce = Nonce::read(&mut input)?;
+                Decryptor::from_v1_header(input, v1_header, nonce)
+            }
+            Header::Unknown(_) => Err(DecryptError::UnknownFormat),
+        }
+    }
+}
+
 #[cfg(feature = "async")]
 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
 impl<R: AsyncRead + Unpin> Decryptor<R> {
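A quick round-trip showing the new constructor from the caller's side. This is a sketch against the API added in this diff plus the existing `Encryptor` passphrase API; errors are simply unwrapped and the passphrase string is illustrative:

```rust
use std::io::{Read, Write};

use age::secrecy::Secret;

fn passphrase() -> age::secrecy::SecretString {
    Secret::new("correct horse battery staple".to_owned())
}

fn main() {
    // Encrypt a small message with a passphrase.
    let ciphertext = {
        let encryptor = age::Encryptor::with_user_passphrase(passphrase());
        let mut out = vec![];
        let mut writer = encryptor.wrap_output(&mut out).unwrap();
        writer.write_all(b"hello age").unwrap();
        writer.finish().unwrap();
        out
    };

    // `&[u8]` implements `BufRead`, so `new_buffered` parses the header with
    // line-sized reads instead of the slower fallback used by `Decryptor::new`.
    let decryptor = match age::Decryptor::new_buffered(&ciphertext[..]).unwrap() {
        age::Decryptor::Passphrase(d) => d,
        _ => unreachable!(),
    };

    let mut reader = decryptor.decrypt(&passphrase(), None).unwrap();
    let mut plaintext = vec![];
    reader.read_to_end(&mut plaintext).unwrap();
    assert_eq!(plaintext, b"hello age".to_vec());
}
```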
+#[test_case("scrypt_work_factor_leading_zero_decimal")] +#[test_case("scrypt_work_factor_leading_zero_octal")] +#[test_case("scrypt_work_factor_missing")] +#[test_case("scrypt_work_factor_negative")] +#[test_case("scrypt_work_factor_overflow")] +#[test_case("scrypt_work_factor_trailing_garbage")] +#[test_case("scrypt_work_factor_wrong")] +#[test_case("scrypt_work_factor_zero")] +#[test_case("stanza_bad_start")] +#[test_case("stanza_base64_padding")] +#[test_case("stanza_empty_argument")] +#[test_case("stanza_empty_body")] +#[test_case("stanza_empty_last_line")] +#[test_case("stanza_invalid_character")] +#[test_case("stanza_long_line")] +#[test_case("stanza_missing_body")] +#[test_case("stanza_missing_final_line")] +#[test_case("stanza_multiple_short_lines")] +#[test_case("stanza_no_arguments")] +#[test_case("stanza_not_canonical")] +#[test_case("stanza_spurious_cr")] +#[test_case("stanza_valid_characters")] +#[test_case("stream_bad_tag")] +#[test_case("stream_bad_tag_second_chunk")] +#[test_case("stream_bad_tag_second_chunk_full")] +#[test_case("stream_empty_payload")] +#[test_case("stream_last_chunk_empty")] +#[test_case("stream_last_chunk_full")] +#[test_case("stream_last_chunk_full_second")] +#[test_case("stream_missing_tag")] +#[test_case("stream_no_chunks")] +#[test_case("stream_no_final")] +#[test_case("stream_no_final_full")] +#[test_case("stream_no_final_two_chunks")] +#[test_case("stream_no_final_two_chunks_full")] +#[test_case("stream_no_nonce")] +#[test_case("stream_short_chunk")] +#[test_case("stream_short_nonce")] +#[test_case("stream_short_second_chunk")] +#[test_case("stream_three_chunks")] +#[test_case("stream_trailing_garbage_long")] +#[test_case("stream_trailing_garbage_short")] +#[test_case("stream_two_chunks")] +#[test_case("stream_two_final_chunks")] +#[test_case("version_unsupported")] +#[test_case("x25519")] +#[test_case("x25519_bad_tag")] +#[test_case("x25519_extra_argument")] +#[test_case("x25519_grease")] +#[test_case("x25519_identity")] +#[test_case("x25519_long_file_key")] +#[test_case("x25519_long_share")] +#[test_case("x25519_lowercase")] +#[test_case("x25519_low_order")] +#[test_case("x25519_multiple_recipients")] +#[test_case("x25519_no_match")] +#[test_case("x25519_not_canonical_body")] +#[test_case("x25519_not_canonical_share")] +#[test_case("x25519_short_share")] #[tokio::test] async fn testkit_async(filename: &str) { let testfile = TestFile::parse(filename); diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 00496938..c1709665 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -34,3 +34,9 @@ path = "fuzz_targets/header.rs" [[bin]] name = "decrypt" path = "fuzz_targets/decrypt.rs" + +[[bin]] +name = "decrypt_buffered" +path = "fuzz_targets/decrypt_buffered.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/decrypt_buffered.rs b/fuzz/fuzz_targets/decrypt_buffered.rs new file mode 100644 index 00000000..553525e8 --- /dev/null +++ b/fuzz/fuzz_targets/decrypt_buffered.rs @@ -0,0 +1,18 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +use std::iter; + +use age::Decryptor; + +fuzz_target!(|data: &[u8]| { + if let Ok(decryptor) = Decryptor::new_buffered(data) { + match decryptor { + Decryptor::Recipients(d) => { + let _ = d.decrypt(iter::empty()); + } + // Don't pay the cost of scrypt while fuzzing. 
diff --git a/fuzz/fuzz_targets/decrypt_buffered.rs b/fuzz/fuzz_targets/decrypt_buffered.rs
new file mode 100644
index 00000000..553525e8
--- /dev/null
+++ b/fuzz/fuzz_targets/decrypt_buffered.rs
@@ -0,0 +1,18 @@
+#![no_main]
+use libfuzzer_sys::fuzz_target;
+
+use std::iter;
+
+use age::Decryptor;
+
+fuzz_target!(|data: &[u8]| {
+    if let Ok(decryptor) = Decryptor::new_buffered(data) {
+        match decryptor {
+            Decryptor::Recipients(d) => {
+                let _ = d.decrypt(iter::empty());
+            }
+            // Don't pay the cost of scrypt while fuzzing.
+            Decryptor::Passphrase(_) => (),
+        }
+    }
+});
diff --git a/rage/CHANGELOG.md b/rage/CHANGELOG.md
index b97f600b..77e2acb9 100644
--- a/rage/CHANGELOG.md
+++ b/rage/CHANGELOG.md
@@ -9,6 +9,10 @@ and this project adheres to Rust's notion of
 to 1.0.0 are beta releases.
 
 ## [Unreleased]
+### Changed
+- Increased parsing speed of age file headers. For single-recipient encrypted
+  files, decryption throughput increases by 6% for medium (< 1MiB) files, and
+  over 40% for small (< 10kiB) files.
 
 ## [0.9.1] - 2023-03-24
 ### Added
diff --git a/rage/src/bin/rage-mount/main.rs b/rage/src/bin/rage-mount/main.rs
index ad1fb14e..a9469568 100644
--- a/rage/src/bin/rage-mount/main.rs
+++ b/rage/src/bin/rage-mount/main.rs
@@ -266,7 +266,7 @@ fn main() -> Result<(), Error> {
     let types = opts.types;
     let mountpoint = opts.mountpoint;
 
-    match age::Decryptor::new(ArmoredReader::new(file))? {
+    match age::Decryptor::new_buffered(ArmoredReader::new(file))? {
         age::Decryptor::Passphrase(decryptor) => {
             match read_secret(&fl!("type-passphrase"), &fl!("prompt-passphrase"), None) {
                 Ok(passphrase) => decryptor
diff --git a/rage/src/bin/rage/main.rs b/rage/src/bin/rage/main.rs
index 3e5fa0bf..5d45c68a 100644
--- a/rage/src/bin/rage/main.rs
+++ b/rage/src/bin/rage/main.rs
@@ -492,7 +492,7 @@ fn decrypt(opts: AgeOptions) -> Result<(), error::DecryptError> {
         ],
     );
 
-    match age::Decryptor::new(ArmoredReader::new(input))? {
+    match age::Decryptor::new_buffered(ArmoredReader::new(input))? {
        age::Decryptor::Passphrase(decryptor) => {
            if !opts.identity.is_empty() {
                return Err(error::DecryptError::MixedIdentityAndPassphrase);
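For downstream code, the migration mirrors the two rage call sites above: anything that already implements `BufRead` (including `ArmoredReader`, as of this change) can go straight into `new_buffered`, while a bare `Read` source just needs a `BufReader` wrapper. A small sketch under those assumptions; the function name and path handling are illustrative only:

```rust
use std::fs::File;
use std::io::BufReader;

use age::{DecryptError, Decryptor};

fn open_decryptor(path: &str) -> Result<Decryptor<BufReader<File>>, DecryptError> {
    // A raw `File` only implements `Read`; wrapping it in `BufReader` lets the
    // buffered header parser work in line-sized chunks.
    let file = File::open(path).map_err(DecryptError::Io)?;
    Decryptor::new_buffered(BufReader::new(file))
}
```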