diff --git a/Cargo.lock b/Cargo.lock index f974b870..16a533eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -130,6 +130,39 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-recursion" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "async-trait" version = "0.1.73" @@ -640,9 +673,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bytes-utils" @@ -1698,18 +1731,18 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pin-project" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842" +checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c" +checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", @@ -1936,10 +1969,12 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util 0.7.8", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "winreg", @@ -2227,6 +2262,8 @@ checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" dependencies = [ "backtrace", "doc-comment", + "futures-core", + "pin-project", "snafu-derive", ] @@ -2465,6 +2502,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89b3cbabd3ae862100094ae433e1def582cf86451b4e9bf83aa7ac1d8a7d719" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.6.10" @@ -2497,9 +2547,14 @@ dependencies = [ name = "tough" version = "0.14.0" dependencies = [ + "async-recursion", + "async-trait", + "bytes", "chrono", "dyn-clone", "failure-server", + "futures", + "futures-core", "globset", "hex", "hex-literal", @@ -2517,6 +2572,8 @@ dependencies = [ "snafu", "tempfile", "tokio", + "tokio-test", + "tokio-util 0.7.8", "typed-path", "untrusted", "url", @@ -2668,6 +2725,8 @@ dependencies = [ "aws-sdk-ssm", "chrono", "clap", + "futures", + "futures-core", "hex", "httptest", "log", @@ -2682,6 +2741,7 @@ dependencies = [ "simplelog", "snafu", "tempfile", + "tokio", "tough", "tough-kms", "tough-ssm", @@ -2878,6 +2938,19 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wasm-streams" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.64" diff --git a/tough-kms/src/client.rs b/tough-kms/src/client.rs index eeca4f29..d0fba11a 100644 --- a/tough-kms/src/client.rs +++ b/tough-kms/src/client.rs @@ -1,38 +1,21 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::error::{self, Result}; use aws_config::default_provider::credentials::DefaultCredentialsChain; use aws_config::default_provider::region::DefaultRegionChain; use aws_sdk_kms::Client as KmsClient; -use snafu::ResultExt; -use std::thread; /// Builds a KMS client for a given profile name. -pub(crate) fn build_client_kms(profile: Option<&str>) -> Result { - // We are cloning this so that we can send it across a thread boundary - let profile = profile.map(std::borrow::ToOwned::to_owned); - // We need to spin up a new thread to deal with the async nature of the - // AWS SDK Rust - let client: Result = thread::spawn(move || { - let runtime = tokio::runtime::Runtime::new().context(error::RuntimeCreationSnafu)?; - Ok(runtime.block_on(async_build_client_kms(profile))) - }) - .join() - .map_err(|_| error::Error::ThreadJoin {})?; - client -} - -async fn async_build_client_kms(profile: Option) -> KmsClient { +pub(crate) async fn build_client_kms(profile: Option<&str>) -> KmsClient { let config = aws_config::from_env(); let client_config = if let Some(profile) = profile { let region = DefaultRegionChain::builder() - .profile_name(&profile) + .profile_name(profile) .build() .region() .await; let creds = DefaultCredentialsChain::builder() - .profile_name(&profile) + .profile_name(profile) .region(region.clone()) .build() .await; diff --git a/tough-kms/src/error.rs b/tough-kms/src/error.rs index f71f643a..5ea8d93c 100644 --- a/tough-kms/src/error.rs +++ b/tough-kms/src/error.rs @@ -17,16 +17,6 @@ pub type Result = std::result::Result; #[non_exhaustive] #[allow(missing_docs)] pub enum Error { - /// The library failed to instantiate 'tokio Runtime'. - #[snafu(display("Unable to create tokio runtime: {}", source))] - RuntimeCreation { - source: std::io::Error, - backtrace: Backtrace, - }, - /// The library failed to join 'tokio Runtime'. - #[snafu(display("Unable to join tokio thread used to offload async workloads"))] - ThreadJoin, - /// The library failed to get public key from AWS KMS #[snafu(display( "Failed to get public key for aws-kms://{}/{} : {}", diff --git a/tough-kms/src/lib.rs b/tough-kms/src/lib.rs index e7eae59a..b7573704 100644 --- a/tough-kms/src/lib.rs +++ b/tough-kms/src/lib.rs @@ -29,6 +29,7 @@ use ring::rand::SecureRandom; use snafu::{ensure, OptionExt, ResultExt}; use std::collections::HashMap; use std::fmt; +use tough::async_trait; use tough::key_source::KeySource; use tough::schema::decoded::{Decoded, RsaPem}; use tough::schema::key::{Key, RsaKey, RsaScheme}; @@ -76,23 +77,22 @@ impl fmt::Debug for KmsKeySource { } /// Implement the `KeySource` trait. +#[async_trait] impl KeySource for KmsKeySource { - fn as_sign( + async fn as_sign( &self, ) -> std::result::Result, Box> { let kms_client = match self.client.clone() { Some(value) => value, - None => client::build_client_kms(self.profile.as_deref())?, + None => client::build_client_kms(self.profile.as_deref()).await, }; // Get the public key from AWS KMS - let fut = kms_client + let response = kms_client .get_public_key() .key_id(self.key_id.clone()) - .send(); - let response = tokio::runtime::Runtime::new() - .context(error::RuntimeCreationSnafu)? - .block_on(fut) + .send() + .await .context(error::KmsGetPublicKeySnafu { profile: self.profile.clone(), key_id: self.key_id.clone(), @@ -131,7 +131,7 @@ impl KeySource for KmsKeySource { })) } - fn write( + async fn write( &self, _value: &str, _key_id_hex: &str, @@ -166,6 +166,7 @@ impl fmt::Debug for KmsRsaKey { } } +#[async_trait] impl Sign for KmsRsaKey { fn tuf_key(&self) -> Key { // Create a Key struct for the public key @@ -179,27 +180,24 @@ impl Sign for KmsRsaKey { } } - fn sign( + async fn sign( &self, msg: &[u8], - _rng: &dyn SecureRandom, + _rng: &(dyn SecureRandom + Sync), ) -> Result, Box> { let kms_client = match self.client.clone() { Some(value) => value, - None => client::build_client_kms(self.profile.as_deref())?, + None => client::build_client_kms(self.profile.as_deref()).await, }; let blob = Blob::new(digest(&SHA256, msg).as_ref().to_vec()); - let sign_fut = kms_client + let response = kms_client .sign() .key_id(self.key_id.clone()) .message(blob) .message_type(aws_sdk_kms::types::MessageType::Digest) .signing_algorithm(self.signing_algorithm.value()) - .send(); - - let response = tokio::runtime::Runtime::new() - .context(error::RuntimeCreationSnafu)? - .block_on(sign_fut) + .send() + .await .context(error::KmsSignMessageSnafu { profile: self.profile.clone(), key_id: self.key_id.clone(), diff --git a/tough-kms/tests/all_test.rs b/tough-kms/tests/all_test.rs index 7f52f586..5d408a06 100644 --- a/tough-kms/tests/all_test.rs +++ b/tough-kms/tests/all_test.rs @@ -45,9 +45,9 @@ struct CreateKeyResp { key_id: String, } -#[test] +#[tokio::test] // Ensure public key is returned on calling tuf_key -fn check_tuf_key_success() { +async fn check_tuf_key_success() { let input = "response_public_key.json"; let key_id = String::from("alias/some_alias"); let file = File::open( @@ -67,15 +67,15 @@ fn check_tuf_key_success() { client: Some(client), signing_algorithm: RsassaPssSha256, }; - let sign = kms_key.as_sign().unwrap(); + let sign = kms_key.as_sign().await.unwrap(); let key = sign.tuf_key(); assert!(matches!(key, Key::Rsa { .. })); assert_eq!(key, expected_key); } -#[test] +#[tokio::test] // Ensure message signature is returned on calling sign -fn check_sign_success() { +async fn check_sign_success() { let resp_public_key = "response_public_key.json"; let resp_signature = "response_signature.json"; let file = File::open( @@ -96,16 +96,17 @@ fn check_sign_success() { signing_algorithm: RsassaPssSha256, }; let rng = SystemRandom::new(); - let kms_sign = kms_key.as_sign().unwrap(); + let kms_sign = kms_key.as_sign().await.unwrap(); let signature = kms_sign .sign("Some message to sign".as_bytes(), &rng) + .await .unwrap(); assert_eq!(signature, expected_signature); } -#[test] +#[tokio::test] // Ensure call to tuf_key fails when public key is not available -fn check_public_key_failure() { +async fn check_public_key_failure() { let client = test_utils::mock_client_with_status(501); let key_id = String::from("alias/some_alias"); let kms_key = KmsKeySource { @@ -114,13 +115,13 @@ fn check_public_key_failure() { client: Some(client), signing_algorithm: RsassaPssSha256, }; - let result = kms_key.as_sign(); + let result = kms_key.as_sign().await; assert!(result.is_err()); } -#[test] +#[tokio::test] // Ensure call to as_sign fails when signing algorithms are missing in get_public_key response -fn check_public_key_missing_algo() { +async fn check_public_key_missing_algo() { let input = "response_public_key_no_algo.json"; let client = test_utils::mock_client(vec![input]); let key_id = String::from("alias/some_alias"); @@ -130,7 +131,7 @@ fn check_public_key_missing_algo() { client: Some(client), signing_algorithm: RsassaPssSha256, }; - let err = kms_key.as_sign().err().unwrap(); + let err = kms_key.as_sign().await.err().unwrap(); assert_eq!( String::from( "Found public key from AWS KMS, but list of supported signing algorithm is missing" @@ -139,9 +140,9 @@ fn check_public_key_missing_algo() { ); } -#[test] +#[tokio::test] // Ensure call to as_sign fails when provided signing algorithm does not match -fn check_public_key_unmatch_algo() { +async fn check_public_key_unmatch_algo() { let input = "response_public_key_unmatch_algo.json"; let key_id = String::from("alias/some_alias"); let client = test_utils::mock_client(vec![input]); @@ -151,16 +152,16 @@ fn check_public_key_unmatch_algo() { client: Some(client), signing_algorithm: RsassaPssSha256, }; - let err = kms_key.as_sign().err().unwrap(); + let err = kms_key.as_sign().await.err().unwrap(); assert_eq!( String::from("Please provide valid signing algorithm"), err.to_string() ); } -#[test] +#[tokio::test] // Ensure sign error when Kms returns empty signature. -fn check_signature_failure() { +async fn check_signature_failure() { let resp_public_key = "response_public_key.json"; let resp_signature = "response_signature_empty.json"; let key_id = String::from("alias/some_alias"); @@ -172,8 +173,8 @@ fn check_signature_failure() { signing_algorithm: RsassaPssSha256, }; let rng = SystemRandom::new(); - let kms_sign = kms_key.as_sign().unwrap(); - let result = kms_sign.sign("Some message to sign".as_bytes(), &rng); + let kms_sign = kms_key.as_sign().await.unwrap(); + let result = kms_sign.sign("Some message to sign".as_bytes(), &rng).await; assert!(result.is_err()); let err = result.err().unwrap(); assert_eq!( @@ -182,8 +183,8 @@ fn check_signature_failure() { ); } -#[test] -fn check_write_ok() { +#[tokio::test] +async fn check_write_ok() { let key_id = String::from("alias/some_alias"); let kms_key = KmsKeySource { profile: None, @@ -191,5 +192,5 @@ fn check_write_ok() { client: None, signing_algorithm: RsassaPssSha256, }; - assert!(kms_key.write("", "").is_ok()); + assert!(kms_key.write("", "").await.is_ok()); } diff --git a/tough-ssm/src/lib.rs b/tough-ssm/src/lib.rs index a6488d6c..81356a6a 100644 --- a/tough-ssm/src/lib.rs +++ b/tough-ssm/src/lib.rs @@ -5,6 +5,7 @@ mod client; pub mod error; use snafu::{OptionExt, ResultExt}; +use tough::async_trait; use tough::key_source::KeySource; use tough::sign::{parse_keypair, Sign}; @@ -17,20 +18,19 @@ pub struct SsmKeySource { } /// Implements the KeySource trait. +#[async_trait] impl KeySource for SsmKeySource { - fn as_sign( + async fn as_sign( &self, ) -> std::result::Result, Box> { let ssm_client = client::build_client(self.profile.as_deref())?; - let fut = ssm_client + let response = ssm_client .get_parameter() .name(self.parameter_name.to_owned()) .with_decryption(true) - .send(); - let response = tokio::runtime::Runtime::new() - .context(error::RuntimeCreationSnafu)? - .block_on(fut) + .send() + .await .context(error::SsmGetParameterSnafu { profile: self.profile.clone(), parameter_name: &self.parameter_name, @@ -52,14 +52,14 @@ impl KeySource for SsmKeySource { Ok(sign) } - fn write( + async fn write( &self, value: &str, key_id_hex: &str, ) -> std::result::Result<(), Box> { let ssm_client = client::build_client(self.profile.as_deref())?; - let fut = ssm_client + ssm_client .put_parameter() .name(self.parameter_name.to_owned()) .description(key_id_hex.to_owned()) @@ -67,15 +67,13 @@ impl KeySource for SsmKeySource { .overwrite(true) .set_type(Some(aws_sdk_ssm::types::ParameterType::SecureString)) .value(value.to_owned()) - .send(); - - tokio::runtime::Runtime::new() - .context(error::RuntimeCreationSnafu)? - .block_on(fut) + .send() + .await .context(error::SsmPutParameterSnafu { profile: self.profile.clone(), parameter_name: &self.parameter_name, })?; + Ok(()) } } diff --git a/tough/Cargo.toml b/tough/Cargo.toml index 8dfff961..5be8827a 100644 --- a/tough/Cargo.toml +++ b/tough/Cargo.toml @@ -9,21 +9,28 @@ keywords = ["tuf", "update", "repository"] edition = "2018" [dependencies] +async-recursion = "1.0.5" +async-trait = "0.1.73" +bytes = "1.5.0" chrono = { version = "0.4", default-features = false, features = ["std", "alloc", "serde", "clock"] } dyn-clone = "1" +futures = "0.3.28" +futures-core = "0.3.28" globset = { version = "0.4" } hex = "0.4" log = "0.4" olpc-cjson = { version = "0.1", path = "../olpc-cjson" } pem = "3" percent-encoding = "2" -reqwest = { version = "0.11", optional = true, default-features = false, features = ["blocking"] } +reqwest = { version = "0.11", optional = true, default-features = false, features = ["stream"] } ring = { version = "0.16", features = ["std"] } serde = { version = "1", features = ["derive"] } serde_json = "1" serde_plain = "1" -snafu = "0.7" +snafu = { version = "0.7", features = ["futures"] } tempfile = "3" +tokio = { version = "1.0", default-features = false, features = ["io-util", "sync", "fs", "time"] } +tokio-util = { version = "0.7.8", features = ["io"] } typed-path = "0.6" untrusted = "0.7" url = "2" @@ -34,7 +41,8 @@ failure-server = { path = "../integ/failure-server", version = "0.1.0" } hex-literal = "0.4" httptest = "0.15" maplit = "1" -tokio = { version = "1.0", features = ["rt-multi-thread"] } +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } +tokio-test = "0.4.3" [features] http = ["reqwest"] diff --git a/tough/src/cache.rs b/tough/src/cache.rs index a3ca4745..97e70502 100644 --- a/tough/src/cache.rs +++ b/tough/src/cache.rs @@ -1,10 +1,14 @@ use crate::error::{self, Result}; use crate::fetch::{fetch_max_size, fetch_sha256}; use crate::schema::{RoleType, Target}; +use crate::transport::IntoVec; use crate::{encode_filename, Prefix, Repository, TargetName}; -use snafu::{OptionExt, ResultExt}; -use std::io::{Read, Write}; +use bytes::Bytes; +use futures::StreamExt; +use futures_core::stream::BoxStream; +use snafu::{futures::TryStreamExt, OptionExt, ResultExt}; use std::path::Path; +use tokio::io::AsyncWriteExt; impl Repository { /// Cache an entire or partial repository to disk, including all required metadata. @@ -15,7 +19,7 @@ impl Repository { /// * `targets_subset` is the list of targets to include in the cached repo. If no subset is /// specified (`None`), then *all* targets are included in the cache. /// * `cache_root_chain` specifies whether or not we will cache all versions of `root.json`. - pub fn cache( + pub async fn cache( &self, metadata_outdir: P1, targets_outdir: P2, @@ -28,35 +32,35 @@ impl Repository { S: AsRef, { // Create the output directories if the do not exist. - std::fs::create_dir_all(metadata_outdir.as_ref()).context( - error::CacheDirectoryCreateSnafu { + tokio::fs::create_dir_all(metadata_outdir.as_ref()) + .await + .context(error::CacheDirectoryCreateSnafu { path: metadata_outdir.as_ref(), - }, - )?; - std::fs::create_dir_all(targets_outdir.as_ref()).context( - error::CacheDirectoryCreateSnafu { + })?; + tokio::fs::create_dir_all(targets_outdir.as_ref()) + .await + .context(error::CacheDirectoryCreateSnafu { path: targets_outdir.as_ref(), - }, - )?; + })?; // Fetch targets and save them to the outdir if let Some(target_list) = targets_subset { for raw_name in target_list { let target_name = TargetName::new(raw_name.as_ref())?; - self.cache_target(&targets_outdir, &target_name)?; + self.cache_target(&targets_outdir, &target_name).await?; } } else { let targets = &self.targets.signed.targets_map(); for target_name in targets.keys() { - self.cache_target(&targets_outdir, target_name)?; + self.cache_target(&targets_outdir, target_name).await?; } } // Cache all metadata - self.cache_metadata_impl(&metadata_outdir)?; + self.cache_metadata_impl(&metadata_outdir).await?; if cache_root_chain { - self.cache_root_chain(&metadata_outdir)?; + self.cache_root_chain(&metadata_outdir).await?; } Ok(()) } @@ -66,27 +70,27 @@ impl Repository { /// /// * `metadata_outdir` is the directory where cached metadata files will be saved. /// * `cache_root_chain` specifies whether or not we will cache all versions of `root.json`. - pub fn cache_metadata

(&self, metadata_outdir: P, cache_root_chain: bool) -> Result<()> + pub async fn cache_metadata

(&self, metadata_outdir: P, cache_root_chain: bool) -> Result<()> where P: AsRef, { // Create the output directory if it does not exist. - std::fs::create_dir_all(metadata_outdir.as_ref()).context( - error::CacheDirectoryCreateSnafu { + tokio::fs::create_dir_all(metadata_outdir.as_ref()) + .await + .context(error::CacheDirectoryCreateSnafu { path: metadata_outdir.as_ref(), - }, - )?; + })?; - self.cache_metadata_impl(&metadata_outdir)?; + self.cache_metadata_impl(&metadata_outdir).await?; if cache_root_chain { - self.cache_root_chain(metadata_outdir)?; + self.cache_root_chain(metadata_outdir).await?; } Ok(()) } /// Cache repository metadata files, including delegated targets metadata - fn cache_metadata_impl

(&self, metadata_outdir: P) -> Result<()> + async fn cache_metadata_impl

(&self, metadata_outdir: P) -> Result<()> where P: AsRef, { @@ -95,19 +99,22 @@ impl Repository { self.max_snapshot_size()?, "timestamp.json", &metadata_outdir, - )?; + ) + .await?; self.cache_file_from_transport( self.targets_filename().as_str(), self.limits.max_targets_size, "max_targets_size argument", &metadata_outdir, - )?; + ) + .await?; self.cache_file_from_transport( "timestamp.json", self.limits.max_timestamp_size, "max_timestamp_size argument", &metadata_outdir, - )?; + ) + .await?; for name in self.targets.signed.role_names() { if let Some(filename) = self.delegated_filename(name) { @@ -116,7 +123,8 @@ impl Repository { self.limits.max_targets_size, "max_targets_size argument", &metadata_outdir, - )?; + ) + .await?; } } @@ -124,7 +132,7 @@ impl Repository { } /// Cache all versions of root.json less than or equal to the current version. - fn cache_root_chain

(&self, outdir: P) -> Result<()> + async fn cache_root_chain

(&self, outdir: P) -> Result<()> where P: AsRef, { @@ -135,7 +143,8 @@ impl Repository { self.limits.max_root_size, "max_root_size argument", &outdir, - )?; + ) + .await?; } Ok(()) } @@ -176,40 +185,45 @@ impl Repository { } /// Copies a file using `Transport` to `outdir`. - fn cache_file_from_transport>( + async fn cache_file_from_transport>( &self, filename: &str, max_size: u64, max_size_specifier: &'static str, outdir: P, ) -> Result<()> { - let mut read = fetch_max_size( + let url = self + .metadata_base_url + .join(filename) + .with_context(|_| error::JoinUrlSnafu { + path: filename, + url: self.metadata_base_url.clone(), + })?; + let stream = fetch_max_size( self.transport.as_ref(), - self.metadata_base_url - .join(filename) - .context(error::JoinUrlSnafu { - path: filename, - url: self.metadata_base_url.clone(), - })?, + url.clone(), max_size, max_size_specifier, - )?; + ) + .await?; let outpath = outdir.as_ref().join(filename); - let mut file = std::fs::File::create(&outpath).context(error::CacheFileWriteSnafu { - path: outpath.clone(), + let mut file = tokio::fs::File::create(&outpath).await.with_context(|_| { + error::CacheFileWriteSnafu { + path: outpath.clone(), + } })?; - let mut root_file_data = Vec::new(); - read.read_to_end(&mut root_file_data) - .context(error::CacheFileReadSnafu { - url: self.metadata_base_url.clone(), - })?; + let root_file_data = stream + .into_vec() + .await + .context(error::TransportSnafu { url })?; file.write_all(&root_file_data) + .await .context(error::CacheFileWriteSnafu { path: outpath }) } /// Saves a signed target to the specified `outdir`. Retains the digest-prepended filename if /// consistent snapshots are used. - fn cache_target>(&self, outdir: P, name: &TargetName) -> Result<()> { + async fn cache_target>(&self, outdir: P, name: &TargetName) -> Result<()> { self.save_target( name, outdir, @@ -219,6 +233,7 @@ impl Repository { Prefix::None }, ) + .await } /// Gets the max size of the snapshot.json file as specified by the timestamp file. @@ -255,23 +270,28 @@ impl Repository { /// Fetches the signed target using `Transport`. Aborts with error if the fetched target is /// larger than its signed size. - pub(crate) fn fetch_target( + pub(crate) async fn fetch_target( &self, target: &Target, digest: &[u8], filename: &str, - ) -> Result { - fetch_sha256( + ) -> Result>> { + let url = self + .targets_base_url + .join(filename) + .with_context(|_| error::JoinUrlSnafu { + path: filename, + url: self.targets_base_url.clone(), + })?; + Ok(fetch_sha256( self.transport.as_ref(), - self.targets_base_url - .join(filename) - .context(error::JoinUrlSnafu { - path: filename, - url: self.targets_base_url.clone(), - })?, + url.clone(), target.length, "targets.json", digest, ) + .await? + .context(error::TransportSnafu { url }) + .boxed()) } } diff --git a/tough/src/datastore.rs b/tough/src/datastore.rs index 5851f15e..aa3dfb4d 100644 --- a/tough/src/datastore.rs +++ b/tough/src/datastore.rs @@ -6,11 +6,11 @@ use chrono::{DateTime, Utc}; use log::debug; use serde::Serialize; use snafu::{ensure, ResultExt}; -use std::fs::{self, File}; -use std::io::{ErrorKind, Read}; +use std::io::ErrorKind; use std::path::{Path, PathBuf}; -use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::Arc; use tempfile::TempDir; +use tokio::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; /// `Datastore` persists TUF metadata files. #[derive(Debug, Clone)] @@ -18,7 +18,7 @@ pub(crate) struct Datastore { /// A lock around retrieving the datastore path. path_lock: Arc>, /// A lock to treat the system_time function as a critical section. - time_lock: Arc>, + time_lock: Arc>, } impl Datastore { @@ -28,34 +28,26 @@ impl Datastore { None => DatastorePath::TempDir(TempDir::new().context(error::DatastoreInitSnafu)?), Some(p) => DatastorePath::Path(p), })), - time_lock: Arc::new(RwLock::new(())), + time_lock: Arc::new(Mutex::new(())), }) } - // Because we are not actually changing the underlying data in the lock, we can ignore when a - // lock is poisoned. - - fn read(&self) -> RwLockReadGuard<'_, DatastorePath> { - self.path_lock - .read() - .unwrap_or_else(PoisonError::into_inner) + async fn read(&self) -> RwLockReadGuard<'_, DatastorePath> { + self.path_lock.read().await } - fn write(&self) -> RwLockWriteGuard<'_, DatastorePath> { - self.path_lock - .write() - .unwrap_or_else(PoisonError::into_inner) + async fn write(&self) -> RwLockWriteGuard<'_, DatastorePath> { + self.path_lock.write().await } - /// Get a reader to a file in the datastore. Caution, this is *not* thread safe. A lock is - /// briefly created on the datastore when the read object is created, but it is released at the - /// end of this function. + /// Get contents of a file in the datastore. This function is thread safe. /// /// TODO: [provide a thread safe interface](https://github.com/awslabs/tough/issues/602) /// - pub(crate) fn reader(&self, file: &str) -> Result> { - let path = self.read().path().join(file); - match File::open(&path) { + pub(crate) async fn bytes(&self, file: &str) -> Result>> { + let lock = &self.read().await; + let path = lock.path().join(file); + match tokio::fs::read(&path).await { Ok(file) => Ok(Some(file)), Err(err) => match err.kind() { ErrorKind::NotFound => Ok(None), @@ -65,23 +57,24 @@ impl Datastore { } /// Writes a JSON metadata file in the datastore. This function is thread safe. - pub(crate) fn create(&self, file: &str, value: &T) -> Result<()> { - let path = self.write().path().join(file); - serde_json::to_writer_pretty( - File::create(&path).context(error::DatastoreCreateSnafu { path: &path })?, - value, - ) - .context(error::DatastoreSerializeSnafu { + pub(crate) async fn create(&self, file: &str, value: &T) -> Result<()> { + let lock = &self.write().await; + let path = lock.path().join(file); + let bytes = serde_json::to_vec(value).with_context(|_| error::DatastoreSerializeSnafu { what: format!("{file} in datastore"), - path, - }) + path: path.clone(), + })?; + tokio::fs::write(&path, bytes) + .await + .context(error::DatastoreCreateSnafu { path: &path }) } /// Deletes a file from the datastore. This function is thread safe. - pub(crate) fn remove(&self, file: &str) -> Result<()> { - let path = self.write().path().join(file); + pub(crate) async fn remove(&self, file: &str) -> Result<()> { + let lock = self.write().await; + let path = lock.path().join(file); debug!("removing '{}'", path.display()); - match fs::remove_file(&path) { + match tokio::fs::remove_file(&path).await { Ok(()) => Ok(()), Err(err) => match err.kind() { ErrorKind::NotFound => Ok(()), @@ -92,21 +85,16 @@ impl Datastore { /// Ensures that system time has not stepped backward since it was last sampled. This function /// is protected by a lock guard to ensure thread safety. - pub(crate) fn system_time(&self) -> Result> { + pub(crate) async fn system_time(&self) -> Result> { // Treat this function as a critical section. This lock is not used for anything else. - let lock = self.time_lock.write().map_err(|e| { - // Painful error type that has a reference and lifetime. Convert it to a message string. - error::DatastoreTimeLockSnafu { - message: e.to_string(), - } - .build() - })?; + let lock = self.time_lock.lock().await; let file = "latest_known_time.json"; // Load the latest known system time, if it exists let poss_latest_known_time = self - .reader(file)? - .map(serde_json::from_reader::<_, DateTime>); + .bytes(file) + .await? + .map(|b| serde_json::from_slice::>(&b)); // Get 'current' system time let sys_time = Utc::now(); @@ -123,7 +111,7 @@ impl Datastore { } // Store the latest known time // Serializes RFC3339 time string and store to datastore - self.create(file, &sys_time)?; + self.create(file, &sys_time).await?; // Explicitly drop the lock to avoid any compiler optimization. drop(lock); diff --git a/tough/src/editor/keys.rs b/tough/src/editor/keys.rs index 5661e7d6..652a9335 100644 --- a/tough/src/editor/keys.rs +++ b/tough/src/editor/keys.rs @@ -14,10 +14,10 @@ pub(crate) type KeyList = HashMap, Box>; impl KeyHolder { /// Creates a key list for the provided keys - pub(crate) fn get_keys(&self, keys: &[Box]) -> Result { + pub(crate) async fn get_keys(&self, keys: &[Box]) -> Result { match self { - Self::Delegations(delegations) => get_targets_keys(delegations, keys), - Self::Root(root) => get_root_keys(root, keys), + Self::Delegations(delegations) => get_targets_keys(delegations, keys).await, + Self::Root(root) => get_root_keys(root, keys).await, } } @@ -74,12 +74,15 @@ impl KeyHolder { /// Gets the corresponding keys from Root (root.json) for the given `KeySource`s. /// This is a convenience function that wraps `Root.key_id()` for multiple /// `KeySource`s. -pub(crate) fn get_root_keys(root: &Root, keys: &[Box]) -> Result { +pub(crate) async fn get_root_keys(root: &Root, keys: &[Box]) -> Result { let mut root_keys = KeyList::new(); for source in keys { // Get a keypair from the given source - let key_pair = source.as_sign().context(error::KeyPairFromKeySourceSnafu)?; + let key_pair = source + .as_sign() + .await + .context(error::KeyPairFromKeySourceSnafu)?; // If the keypair matches any of the keys in the root.json, // add its ID and corresponding keypair the map to be returned @@ -94,14 +97,17 @@ pub(crate) fn get_root_keys(root: &Root, keys: &[Box]) -> Result< /// Gets the corresponding keys from delegations for the given `KeySource`s. /// This is a convenience function that wraps `Delegations.key_id()` for multiple /// `KeySource`s. -pub(crate) fn get_targets_keys( +pub(crate) async fn get_targets_keys( delegations: &Delegations, keys: &[Box], ) -> Result { let mut delegations_keys = KeyList::new(); for source in keys { // Get a keypair from the given source - let key_pair = source.as_sign().context(error::KeyPairFromKeySourceSnafu)?; + let key_pair = source + .as_sign() + .await + .context(error::KeyPairFromKeySourceSnafu)?; // If the keypair matches any of the keys in the delegations metadata, // add its ID and corresponding keypair the map to be returned if let Some(key_id) = delegations.key_id(key_pair.as_ref()) { diff --git a/tough/src/editor/mod.rs b/tough/src/editor/mod.rs index a5354fff..377d725e 100644 --- a/tough/src/editor/mod.rs +++ b/tough/src/editor/mod.rs @@ -20,7 +20,7 @@ use crate::schema::{ Hashes, KeyHolder, PathSet, Role, RoleType, Root, Signed, Snapshot, SnapshotMeta, Target, Targets, Timestamp, TimestampMeta, }; -use crate::transport::Transport; +use crate::transport::{IntoVec, Transport}; use crate::{encode_filename, Limits}; use crate::{Repository, TargetName}; use chrono::{DateTime, Utc}; @@ -87,15 +87,16 @@ pub struct RepositoryEditor { impl RepositoryEditor { /// Create a new, bare `RepositoryEditor` - pub fn new

(root_path: P) -> Result + pub async fn new

(root_path: P) -> Result where P: AsRef, { // Read and parse the root.json. Without a good root, it doesn't // make sense to continue let root_path = root_path.as_ref(); - let root_buf = - std::fs::read(root_path).context(error::FileReadSnafu { path: root_path })?; + let root_buf = tokio::fs::read(root_path) + .await + .context(error::FileReadSnafu { path: root_path })?; let root_buf_len = root_buf.len() as u64; let root = serde_json::from_slice::>(&root_buf) .context(error::FileParseJsonSnafu { path: root_path })?; @@ -143,11 +144,11 @@ impl RepositoryEditor { /// `RepositoryEditor`. This `RepositoryEditor` will include all of the targets /// and bits of _extra metadata from the roles included. It will not, however, /// include the versions or expirations and the user is expected to set them. - pub fn from_repo

(root_path: P, repo: Repository) -> Result + pub async fn from_repo

(root_path: P, repo: Repository) -> Result where P: AsRef, { - let mut editor = RepositoryEditor::new(root_path)?; + let mut editor = RepositoryEditor::new(root_path).await?; editor.targets(repo.targets)?; editor.snapshot(repo.snapshot.signed)?; editor.timestamp(repo.timestamp.signed)?; @@ -162,11 +163,11 @@ impl RepositoryEditor { /// While `RepositoryEditor`s fields are all `Option`s, this step requires, /// at the very least, that the "version" and "expiration" field is set for /// each role; e.g. `targets_version`, `targets_expires`, etc. - pub fn sign(mut self, keys: &[Box]) -> Result { + pub async fn sign(mut self, keys: &[Box]) -> Result { let rng = SystemRandom::new(); let root = KeyHolder::Root(self.signed_root.signed.signed.clone()); // Sign the targets editor if able to with the provided keys - self.sign_targets_editor(keys)?; + self.sign_targets_editor(keys).await?; let targets = self.signed_targets.clone().context(error::NoTargetsSnafu)?; let delegated_targets = targets.signed.signed_delegated_targets(); let signed_targets = SignedRole::from_signed(targets)?; @@ -189,12 +190,10 @@ impl RepositoryEditor { }) }; - let signed_snapshot = self - .build_snapshot(&signed_targets, &signed_delegated_targets) - .and_then(|snapshot| SignedRole::new(snapshot, &root, keys, &rng))?; - let signed_timestamp = self - .build_timestamp(&signed_snapshot) - .and_then(|timestamp| SignedRole::new(timestamp, &root, keys, &rng))?; + let signed_snapshot = self.build_snapshot(&signed_targets, &signed_delegated_targets)?; + let signed_snapshot = SignedRole::new(signed_snapshot, &root, keys, &rng).await?; + let signed_timestamp = self.build_timestamp(&signed_snapshot)?; + let signed_timestamp = SignedRole::new(signed_timestamp, &root, keys, &rng).await?; // This validation can only be done from the top level targets.json role. This check verifies // that each target's delegate hierarchy is a match (i.e. its delegate ownership is valid). @@ -289,11 +288,11 @@ impl RepositoryEditor { /// no multithreading or parallelism is used. If you have a large number /// of targets to add, and require advanced performance, you may want to /// construct `Target`s directly in parallel and use `add_target()`. - pub fn add_target_path

(&mut self, target_path: P) -> Result<&mut Self> + pub async fn add_target_path

(&mut self, target_path: P) -> Result<&mut Self> where P: AsRef, { - let (target_name, target) = RepositoryEditor::build_target(target_path)?; + let (target_name, target) = RepositoryEditor::build_target(target_path).await?; self.add_target(target_name, target)?; Ok(self) } @@ -301,12 +300,12 @@ impl RepositoryEditor { /// Add a list of target paths to the repository /// /// See the note on `add_target_path()` regarding performance. - pub fn add_target_paths

(&mut self, targets: Vec

) -> Result<&mut Self> + pub async fn add_target_paths

(&mut self, targets: Vec

) -> Result<&mut Self> where P: AsRef, { for target in targets { - let (target_name, target) = RepositoryEditor::build_target(target)?; + let (target_name, target) = RepositoryEditor::build_target(target).await?; self.add_target(target_name, target)?; } @@ -314,7 +313,7 @@ impl RepositoryEditor { } /// Builds a target struct for the given path - pub fn build_target

(target_path: P) -> Result<(TargetName, Target)> + pub async fn build_target

(target_path: P) -> Result<(TargetName, Target)> where P: AsRef, { @@ -331,6 +330,7 @@ impl RepositoryEditor { // Build a Target from the path given. If it is not a file, this will fail let target = Target::from_path(target_path) + .await .context(error::TargetFromPathSnafu { path: target_path })?; Ok((target_name, target)) @@ -346,7 +346,7 @@ impl RepositoryEditor { /// Delegate target with name as a `DelegatedRole` of the `Targets` in `targets_editor` /// This should be used if a role needs to be created by a user with `snapshot.json`, /// `timestamp.json`, and the new role's keys. - pub fn delegate_role( + pub async fn delegate_role( &mut self, name: &str, key_source: &[Box], @@ -360,13 +360,14 @@ impl RepositoryEditor { // Set the version and expiration new_targets_editor.version(version).expires(expiration); // Sign the new targets - let new_targets = new_targets_editor.create_signed(key_source)?; + let new_targets = new_targets_editor.create_signed(key_source).await?; // Find the keyids for key_source let mut keyids = Vec::new(); let mut key_pairs = HashMap::new(); for source in key_source { let key_pair = source .as_sign() + .await .context(error::KeyPairFromKeySourceSnafu)? .tuf_key(); keyids.push( @@ -432,9 +433,9 @@ impl RepositoryEditor { /// Takes the current Targets from `targets_editor` and inserts the role to its proper place in `signed_targets` /// Sets `targets_editor` to None /// Must be called before `change_delegated_targets()` - pub fn sign_targets_editor(&mut self, keys: &[Box]) -> Result<&mut Self> { + pub async fn sign_targets_editor(&mut self, keys: &[Box]) -> Result<&mut Self> { if let Some(targets_editor) = self.targets_editor.as_mut() { - let (name, targets) = targets_editor.create_signed(keys)?.targets(); + let (name, targets) = targets_editor.create_signed(keys).await?.targets(); if name == "targets" { self.signed_targets = Some(targets); } else { @@ -495,7 +496,7 @@ impl RepositoryEditor { /// `metadata_url` and update the repository's metadata for the role /// This method uses the result of `SignedDelegatedTargets::write()` /// Clears the current `targets_editor` - pub fn update_delegated_targets( + pub async fn update_delegated_targets( &mut self, name: &str, metadata_url: &str, @@ -522,15 +523,20 @@ impl RepositoryEditor { filename: encoded_filename, url: metadata_base_url.clone(), })?; - let reader = Box::new(fetch_max_size( + let stream = fetch_max_size( transport.as_ref(), - role_url, + role_url.clone(), limits.max_targets_size, "max targets limit", - )?); + ) + .await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url: role_url })?; // Load incoming role metadata as Signed let mut role: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Targets, })?; //verify role with the parent delegation @@ -587,15 +593,20 @@ impl RepositoryEditor { filename: encoded_filename, url: metadata_base_url.clone(), })?; - let reader = Box::new(fetch_max_size( + let stream = fetch_max_size( transport.as_ref(), - role_url, + role_url.clone(), limits.max_targets_size, "max targets limit", - )?); + ) + .await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url: role_url })?; // Load new role metadata as Signed - let new_role: Signed = serde_json::from_reader(reader) - .context(error::ParseMetadataSnafu { + let new_role: Signed = + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Targets, })?; // verify the role @@ -629,7 +640,7 @@ impl RepositoryEditor { /// Adds a role to the targets currently in `targets_editor` /// using a metadata file located at `metadata_url`/`name`.json /// `add_role()` uses `TargetsEditor::add_role()` to add a role from an existing metadata file. - pub fn add_role( + pub async fn add_role( &mut self, name: &str, metadata_url: &str, @@ -646,7 +657,8 @@ impl RepositoryEditor { self.targets_editor_mut()?.limits(limits); self.targets_editor_mut()?.transport(transport.clone()); self.targets_editor_mut()? - .add_role(name, metadata_url, paths, threshold, keys)?; + .add_role(name, metadata_url, paths, threshold, keys) + .await?; Ok(self) } diff --git a/tough/src/editor/signed.rs b/tough/src/editor/signed.rs index df3ce10e..eea91d05 100644 --- a/tough/src/editor/signed.rs +++ b/tough/src/editor/signed.rs @@ -7,12 +7,14 @@ //! signing, ready to be written to disk. use crate::error::{self, Result}; -use crate::io::DigestAdapter; +use crate::io::{is_file, DigestAdapter}; use crate::key_source::KeySource; use crate::schema::{ DelegatedTargets, KeyHolder, Role, RoleType, Root, Signature, Signed, Snapshot, Target, Targets, Timestamp, }; +use async_trait::async_trait; +use futures::TryStreamExt; use olpc_cjson::CanonicalFormatter; use ring::digest::{digest, SHA256, SHA256_OUTPUT_LEN}; use ring::rand::SecureRandom; @@ -20,14 +22,15 @@ use serde::{Deserialize, Serialize}; use serde_plain::derive_fromstr_from_deserialize; use snafu::{ensure, OptionExt, ResultExt}; use std::collections::HashMap; -use std::fs; +use std::future::{ready, Future}; +use tokio::fs::{canonicalize, copy, create_dir_all, remove_file, symlink_metadata}; #[cfg(not(target_os = "windows"))] -use std::os::unix::fs::symlink; +use tokio::fs::symlink; #[cfg(target_os = "windows")] -use std::os::windows::fs::symlink_file as symlink; +use tokio::fs::symlink_file as symlink; -use crate::TargetName; +use crate::{FilesystemTransport, TargetName, Transport}; use std::borrow::Cow; use std::path::{Path, PathBuf}; use url::Url; @@ -53,13 +56,13 @@ where T: Role + Serialize, { /// Creates a new `SignedRole` - pub fn new( + pub async fn new( role: T, key_holder: &KeyHolder, keys: &[Box], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> Result { - let root_keys = key_holder.get_keys(keys)?; + let root_keys = key_holder.get_keys(keys).await?; let role_keys = key_holder.role_keys(role.role_id())?; // Ensure the keys we have available to us will allow us @@ -86,6 +89,7 @@ where for (signing_key_id, signing_key) in valid_keys { let sig = signing_key .sign(&data, rng) + .await .context(error::SignMessageSnafu)?; // Add the signatures to the `Signed` struct for this role @@ -154,17 +158,21 @@ where /// Write the current role's buffer to the given directory with the /// appropriate file name. - pub fn write

(&self, outdir: P, consistent_snapshot: bool) -> Result<()> + pub async fn write

(&self, outdir: P, consistent_snapshot: bool) -> Result<()> where P: AsRef, { let outdir = outdir.as_ref(); - std::fs::create_dir_all(outdir).context(error::DirCreateSnafu { path: outdir })?; + tokio::fs::create_dir_all(outdir) + .await + .context(error::DirCreateSnafu { path: outdir })?; let filename = self.signed.signed.filename(consistent_snapshot); let path = outdir.join(filename); - std::fs::write(&path, &self.buffer).context(error::FileWriteSnafu { path }) + tokio::fs::write(&path, &self.buffer) + .await + .context(error::FileWriteSnafu { path }) } /// Append the old signatures for root role @@ -236,17 +244,19 @@ pub struct SignedRepository { impl SignedRepository { /// Writes the metadata to the given directory. If consistent snapshots /// are used, the appropriate files are prefixed with their version. - pub fn write

(&self, outdir: P) -> Result<()> + pub async fn write

(&self, outdir: P) -> Result<()> where P: AsRef, { let consistent_snapshot = self.root.signed.signed.consistent_snapshot; - self.root.write(&outdir, consistent_snapshot)?; - self.targets.write(&outdir, consistent_snapshot)?; - self.snapshot.write(&outdir, consistent_snapshot)?; - self.timestamp.write(&outdir, consistent_snapshot)?; + self.root.write(&outdir, consistent_snapshot).await?; + self.targets.write(&outdir, consistent_snapshot).await?; + self.snapshot.write(&outdir, consistent_snapshot).await?; + self.timestamp.write(&outdir, consistent_snapshot).await?; if let Some(delegated_targets) = &self.delegated_targets { - delegated_targets.write(&outdir, consistent_snapshot)?; + delegated_targets + .write(&outdir, consistent_snapshot) + .await?; } Ok(()) } @@ -259,7 +269,7 @@ impl SignedRepository { /// if the filename exists in `Targets`, the file's sha256 is compared /// against the data in `Targets`. If this data does not match, the /// method will fail. - pub fn link_targets( + pub async fn link_targets( &self, indir: P1, outdir: P2, @@ -275,6 +285,7 @@ impl SignedRepository { Self::link_target, replace_behavior, ) + .await } /// Crawls a given directory and copies any targets found to the given @@ -285,7 +296,7 @@ impl SignedRepository { /// if the filename exists in `Targets`, the file's sha256 is compared /// against the data in `Targets`. If this data does not match, the /// method will fail. - pub fn copy_targets( + pub async fn copy_targets( &self, indir: P1, outdir: P2, @@ -301,6 +312,7 @@ impl SignedRepository { Self::copy_target, replace_behavior, ) + .await } /// Symlinks a single target to the desired directory. If `target_filename` is given, it @@ -309,7 +321,7 @@ impl SignedRepository { /// the repo with a different hash, or if it has the same hash but is not a symlink. Using the /// `replace_behavior` parameter, you can decide what happens if it exists with the same hash /// and file type - skip, fail, or replace. - pub fn link_target( + pub async fn link_target( &self, input_path: &Path, outdir: &Path, @@ -317,19 +329,28 @@ impl SignedRepository { target_filename: Option<&TargetName>, ) -> Result<()> { ensure!( - input_path.is_file(), + is_file(input_path).await, error::PathIsNotFileSnafu { path: input_path } ); - match self.target_path(input_path, outdir, target_filename)? { + match self + .target_path(input_path, outdir, target_filename) + .await? + { TargetPath::New { path } => { - symlink(input_path, &path).context(error::LinkCreateSnafu { path })?; + symlink(input_path, &path) + .await + .context(error::LinkCreateSnafu { path })?; } TargetPath::Symlink { path } => match replace_behavior { PathExists::Skip => {} PathExists::Fail => error::PathExistsFailSnafu { path }.fail()?, PathExists::Replace => { - fs::remove_file(&path).context(error::RemoveTargetSnafu { path: &path })?; - symlink(input_path, &path).context(error::LinkCreateSnafu { path })?; + remove_file(&path) + .await + .context(error::RemoveTargetSnafu { path: &path })?; + symlink(input_path, &path) + .await + .context(error::LinkCreateSnafu { path })?; } }, TargetPath::File { path } => { @@ -351,7 +372,7 @@ impl SignedRepository { /// with a different hash, or if it has the same hash but is not a regular file. Using the /// `replace_behavior` parameter, you can decide what happens if it exists with the same hash /// and file type - skip, fail, or replace. - pub fn copy_target( + pub async fn copy_target( &self, input_path: &Path, outdir: &Path, @@ -359,19 +380,28 @@ impl SignedRepository { target_filename: Option<&TargetName>, ) -> Result<()> { ensure!( - input_path.is_file(), + is_file(input_path).await, error::PathIsNotFileSnafu { path: input_path } ); - match self.target_path(input_path, outdir, target_filename)? { + match self + .target_path(input_path, outdir, target_filename) + .await? + { TargetPath::New { path } => { - fs::copy(input_path, &path).context(error::FileWriteSnafu { path })?; + copy(input_path, &path) + .await + .context(error::FileWriteSnafu { path })?; } TargetPath::File { path } => match replace_behavior { PathExists::Skip => {} PathExists::Fail => error::PathExistsFailSnafu { path }.fail()?, PathExists::Replace => { - fs::remove_file(&path).context(error::RemoveTargetSnafu { path: &path })?; - fs::copy(input_path, &path).context(error::FileWriteSnafu { path })?; + remove_file(&path) + .await + .context(error::RemoveTargetSnafu { path: &path })?; + copy(input_path, &path) + .await + .context(error::FileWriteSnafu { path })?; } }, TargetPath::Symlink { path } => { @@ -410,12 +440,12 @@ pub struct SignedDelegatedTargets { impl SignedDelegatedTargets { /// Writes the metadata to the given directory. If consistent snapshots /// are used, the appropriate files are prefixed with their version. - pub fn write

(&self, outdir: P, consistent_snapshot: bool) -> Result<()> + pub async fn write

(&self, outdir: P, consistent_snapshot: bool) -> Result<()> where P: AsRef, { for targets in &self.roles { - targets.write(&outdir, consistent_snapshot)?; + targets.write(&outdir, consistent_snapshot).await?; } Ok(()) } @@ -433,7 +463,7 @@ impl SignedDelegatedTargets { /// if the filename exists in `Targets`, the file's sha256 is compared /// against the data in `Targets`. If this data does not match, the /// method will fail. - pub fn link_targets( + pub async fn link_targets( &self, indir: P1, outdir: P2, @@ -449,6 +479,7 @@ impl SignedDelegatedTargets { Self::link_target, replace_behavior, ) + .await } /// Crawls a given directory and copies any targets found to the given @@ -459,7 +490,7 @@ impl SignedDelegatedTargets { /// if the filename exists in `Targets`, the file's sha256 is compared /// against the data in `Targets`. If this data does not match, the /// method will fail. - pub fn copy_targets( + pub async fn copy_targets( &self, indir: P1, outdir: P2, @@ -475,6 +506,7 @@ impl SignedDelegatedTargets { Self::copy_target, replace_behavior, ) + .await } /// Symlinks a single target to the desired directory. If `target_filename` is given, it @@ -483,7 +515,7 @@ impl SignedDelegatedTargets { /// the repo with a different hash, or if it has the same hash but is not a symlink. Using the /// `replace_behavior` parameter, you can decide what happens if it exists with the same hash /// and file type - skip, fail, or replace. - pub fn link_target( + pub async fn link_target( &self, input_path: &Path, outdir: &Path, @@ -491,19 +523,28 @@ impl SignedDelegatedTargets { target_filename: Option<&TargetName>, ) -> Result<()> { ensure!( - input_path.is_file(), + is_file(input_path).await, error::PathIsNotFileSnafu { path: input_path } ); - match self.target_path(input_path, outdir, target_filename)? { + match self + .target_path(input_path, outdir, target_filename) + .await? + { TargetPath::New { path } => { - symlink(input_path, &path).context(error::LinkCreateSnafu { path })?; + symlink(input_path, &path) + .await + .context(error::LinkCreateSnafu { path })?; } TargetPath::Symlink { path } => match replace_behavior { PathExists::Skip => {} PathExists::Fail => error::PathExistsFailSnafu { path }.fail()?, PathExists::Replace => { - fs::remove_file(&path).context(error::RemoveTargetSnafu { path: &path })?; - symlink(input_path, &path).context(error::LinkCreateSnafu { path })?; + remove_file(&path) + .await + .context(error::RemoveTargetSnafu { path: &path })?; + symlink(input_path, &path) + .await + .context(error::LinkCreateSnafu { path })?; } }, TargetPath::File { path } => { @@ -525,7 +566,7 @@ impl SignedDelegatedTargets { /// with a different hash, or if it has the same hash but is not a regular file. Using the /// `replace_behavior` parameter, you can decide what happens if it exists with the same hash /// and file type - skip, fail, or replace. - pub fn copy_target( + pub async fn copy_target( &self, input_path: &Path, outdir: &Path, @@ -533,19 +574,28 @@ impl SignedDelegatedTargets { target_filename: Option<&TargetName>, ) -> Result<()> { ensure!( - input_path.is_file(), + is_file(input_path).await, error::PathIsNotFileSnafu { path: input_path } ); - match self.target_path(input_path, outdir, target_filename)? { + match self + .target_path(input_path, outdir, target_filename) + .await? + { TargetPath::New { path } => { - fs::copy(input_path, &path).context(error::FileWriteSnafu { path })?; + copy(input_path, &path) + .await + .context(error::FileWriteSnafu { path })?; } TargetPath::File { path } => match replace_behavior { PathExists::Skip => {} PathExists::Fail => error::PathExistsFailSnafu { path }.fail()?, PathExists::Replace => { - fs::remove_file(&path).context(error::RemoveTargetSnafu { path: &path })?; - fs::copy(input_path, &path).context(error::FileWriteSnafu { path })?; + remove_file(&path) + .await + .context(error::RemoveTargetSnafu { path: &path })?; + copy(input_path, &path) + .await + .context(error::FileWriteSnafu { path })?; } }, TargetPath::Symlink { path } => { @@ -578,10 +628,27 @@ impl TargetsWalker for SignedDelegatedTargets { } } +/// Wrapper trait to help with HKTB lifetimes +trait WalkOperator: + FnMut(S, In, Out, PathExists, Option) -> >::Fut +{ + type Fut: Future>::Output> + Send; + type Output; +} +impl WalkOperator for F +where + F: FnMut(S, In, Out, PathExists, Option) -> Fut, + Fut: Future + Send, +{ + type Fut = Fut; + type Output = Fut::Output; +} + /// `TargetsWalker` is used to unify the logic related to copying and linking targets. /// `TargetsWalker`'s default implementation of `walk_targets()` and `target_path()` use /// the trait's `targets()` and `consistent_snapshot()` methods to get a map of targets and /// also determine if a file prefix needs to be used. +#[async_trait] trait TargetsWalker { /// Returns a map of all targets this manager is responsible for fn targets(&self) -> HashMap; @@ -591,26 +658,46 @@ trait TargetsWalker { /// Walks a given directory and calls the provided function with every file found. /// The function is given the file path, the output directory where the user expects /// it to go, and optionally a desired filename. - fn walk_targets( + async fn walk_targets( &self, indir: &Path, outdir: &Path, - f: F, + mut f: F, replace_behavior: PathExists, ) -> Result<()> where - F: Fn(&Self, &Path, &Path, PathExists, Option<&TargetName>) -> Result<()>, + F: for<'a, 'b, 'c, 'd> WalkOperator< + &'a Self, + &'b Path, + &'c Path, + &'d TargetName, + Output = Result<()>, + > + Send, { - std::fs::create_dir_all(outdir).context(error::DirCreateSnafu { path: outdir })?; + create_dir_all(outdir) + .await + .context(error::DirCreateSnafu { path: outdir })?; // Get the absolute path of the indir and outdir - let abs_indir = - std::fs::canonicalize(indir).context(error::AbsolutePathSnafu { path: indir })?; + let abs_indir = canonicalize(indir) + .await + .context(error::AbsolutePathSnafu { path: indir })?; // Walk the absolute path of the indir. Using the absolute path here // means that `entry.path()` call will return its absolute path. - let walker = WalkDir::new(&abs_indir).follow_links(true); - for entry in walker { + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + let root = abs_indir.clone(); + tokio::task::spawn_blocking(move || { + let walker = WalkDir::new(&root).follow_links(true); + for entry in walker { + if tx.blocking_send(entry).is_err() { + // Receiver error'ed out + break; + }; + } + }); + + while let Some(entry) = rx.recv().await { let entry = entry.context(error::WalkDirSnafu { directory: &abs_indir, })?; @@ -621,7 +708,7 @@ trait TargetsWalker { }; // Call the requested function to manipulate the path we found - if let Err(e) = f(self, entry.path(), outdir, replace_behavior, None) { + if let Err(e) = f(self, entry.path(), outdir, replace_behavior, None).await { match e { // If we found a path that isn't a known target in the repo, skip it. error::Error::PathIsNotTarget { .. } => continue, @@ -636,14 +723,15 @@ trait TargetsWalker { /// Determines the output path of a target based on consistent snapshot rules. Returns Err if /// the target already exists in the repo with a different hash, or if the target is not known /// to the repo. (We're dealing with a signed repo, so it's too late to add targets.) - fn target_path( + async fn target_path( &self, input: &Path, outdir: &Path, target_filename: Option<&TargetName>, ) -> Result { - let outdir = - std::fs::canonicalize(outdir).context(error::AbsolutePathSnafu { path: outdir })?; + let outdir = tokio::fs::canonicalize(outdir) + .await + .context(error::AbsolutePathSnafu { path: outdir })?; // If the caller requested a specific target filename, use that, otherwise use the filename // component of the input path. @@ -660,8 +748,9 @@ trait TargetsWalker { }; // create a Target object using the input path. - let target_from_path = - Target::from_path(input).context(error::TargetFromPathSnafu { path: input })?; + let target_from_path = Target::from_path(input) + .await + .context(error::TargetFromPathSnafu { path: input })?; // Use the file name to see if a target exists in the repo // with that name. If so... @@ -701,25 +790,27 @@ trait TargetsWalker { // unique; if we're not, then there could be a target from another repo with the same name // but different checksum. We can't assume such conflicts are OK, so we fail. if !self.consistent_snapshot() { - // Use DigestAdapter to get a streaming checksum of the file without needing to hold - // its contents. - let f = fs::File::open(&dest).context(error::FileOpenSnafu { path: &dest })?; - let mut reader = DigestAdapter::sha256( - Box::new(f), - &repo_target.hashes.sha256, - Url::from_file_path(&dest) - .ok() // dump unhelpful `()` error - .context(error::FileUrlSnafu { path: &dest })?, - ); - let mut dev_null = std::io::sink(); + let url = Url::from_file_path(&dest) + .ok() // dump unhelpful `()` error + .context(error::FileUrlSnafu { path: &dest })?; + + let stream = FilesystemTransport + .fetch(url.clone()) + .await + .with_context(|_| error::TransportSnafu { url: url.clone() })?; + let stream = DigestAdapter::sha256(stream, &repo_target.hashes.sha256, url.clone()); + // The act of reading with the DigestAdapter verifies the checksum, assuming the read // succeeds. - std::io::copy(&mut reader, &mut dev_null) - .context(error::FileReadSnafu { path: &dest })?; + stream + .try_for_each(|_| ready(Ok(()))) + .await + .context(error::TransportSnafu { url })?; } - let metadata = - fs::symlink_metadata(&dest).context(error::FileMetadataSnafu { path: &dest })?; + let metadata = symlink_metadata(&dest) + .await + .context(error::FileMetadataSnafu { path: &dest })?; if metadata.file_type().is_file() { Ok(TargetPath::File { path: dest }) } else if metadata.file_type().is_symlink() { diff --git a/tough/src/editor/targets.rs b/tough/src/editor/targets.rs index 47e08b76..1cbb0cbb 100644 --- a/tough/src/editor/targets.rs +++ b/tough/src/editor/targets.rs @@ -14,7 +14,7 @@ use crate::schema::{ DelegatedRole, DelegatedTargets, Delegations, KeyHolder, PathSet, RoleType, Signed, Target, Targets, }; -use crate::transport::Transport; +use crate::transport::{IntoVec, Transport}; use crate::{encode_filename, Limits}; use crate::{Repository, TargetName}; use chrono::{DateTime, Utc}; @@ -200,7 +200,7 @@ impl TargetsEditor { /// no multithreading or parallelism is used. If you have a large number /// of targets to add, and require advanced performance, you may want to /// construct `Target`s directly in parallel and use `add_target()`. - pub fn add_target_path

(&mut self, target_path: P) -> Result<&mut Self> + pub async fn add_target_path

(&mut self, target_path: P) -> Result<&mut Self> where P: AsRef, { @@ -217,6 +217,7 @@ impl TargetsEditor { // Build a Target from the path given. If it is not a file, this will fail let target = Target::from_path(target_path) + .await .context(error::TargetFromPathSnafu { path: target_path })?; self.add_target(target_name, target)?; @@ -226,12 +227,12 @@ impl TargetsEditor { /// Add a list of target paths to the targets /// /// See the note on `add_target_path()` regarding performance. - pub fn add_target_paths

(&mut self, targets: Vec

) -> Result<&mut Self> + pub async fn add_target_paths

(&mut self, targets: Vec

) -> Result<&mut Self> where P: AsRef, { for target in targets { - self.add_target_path(target)?; + self.add_target_path(target).await?; } Ok(self) } @@ -382,7 +383,7 @@ impl TargetsEditor { /// Adds a role to `new_roles` using a metadata file located at `metadata_url`/`name`.json /// `add_role()` uses `delegate_role()` to add a role from an existing metadata file. - pub fn add_role( + pub async fn add_role( &mut self, name: &str, metadata_url: &str, @@ -409,15 +410,20 @@ impl TargetsEditor { filename: encoded_filename, url: metadata_base_url, })?; - let reader = Box::new(fetch_max_size( + let stream = fetch_max_size( transport, - role_url, + role_url.clone(), limits.max_targets_size, "max targets limit", - )?); + ) + .await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url: role_url })?; // Load incoming role metadata as Signed let role: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Targets, })?; @@ -489,7 +495,7 @@ impl TargetsEditor { } /// Creates a `KeyHolder` to sign the `Targets` role with the signing keys provided - fn create_key_holder(&self, keys: &[Box]) -> Result { + async fn create_key_holder(&self, keys: &[Box]) -> Result { // There isn't a KeyHolder, so create one based on the provided keys let mut delegations = Delegations::new(); // First create the tuf key pairs and keyids @@ -498,6 +504,7 @@ impl TargetsEditor { for source in keys { let key_pair = source .as_sign() + .await .context(error::KeyPairFromKeySourceSnafu)? .tuf_key(); key_pairs.insert( @@ -533,36 +540,37 @@ impl TargetsEditor { /// like `sign()` creates. `SignedDelegatedTargets` can contain more than 1 `Signed` /// `create_signed()` guarantees that only 1 `Signed` is created and that it is the one representing /// the current targets. `create_signed()` should be used whenever the result of `TargetsEditor` is not being written. - pub fn create_signed(&self, keys: &[Box]) -> Result> { + pub async fn create_signed( + &self, + keys: &[Box], + ) -> Result> { let rng = SystemRandom::new(); let key_holder = if let Some(key_holder) = self.key_holder.as_ref() { key_holder.clone() } else { - self.create_key_holder(keys)? + self.create_key_holder(keys).await? }; // create a signed role for the targets being edited - let targets = self - .build_targets() - .and_then(|targets| SignedRole::new(targets, &key_holder, keys, &rng))?; + let targets = self.build_targets()?; + let targets = SignedRole::new(targets, &key_holder, keys, &rng).await?; Ok(targets.signed) } /// Creates a `SignedDelegatedTargets` for the Targets role being edited and all added roles /// If `key_holder` was not assigned then this is a newly created role and needs to be signed with a /// custom delegations as its `key_holder` - pub fn sign(&self, keys: &[Box]) -> Result { + pub async fn sign(&self, keys: &[Box]) -> Result { let rng = SystemRandom::new(); let mut roles = Vec::new(); let key_holder = if let Some(key_holder) = self.key_holder.as_ref() { key_holder.clone() } else { - self.create_key_holder(keys)? + self.create_key_holder(keys).await? }; // create a signed role for the targets we are editing - let signed_targets = self - .build_targets() - .and_then(|targets| SignedRole::new(targets, &key_holder, keys, &rng))?; + let signed_targets = self.build_targets()?; + let signed_targets = SignedRole::new(signed_targets, &key_holder, keys, &rng).await?; roles.push(signed_targets); // create signed roles for any role metadata we added to this targets if let Some(new_roles) = &self.new_roles { diff --git a/tough/src/editor/test.rs b/tough/src/editor/test.rs index e821ab72..5f771df9 100644 --- a/tough/src/editor/test.rs +++ b/tough/src/editor/test.rs @@ -47,20 +47,20 @@ mod tests { } // Make sure we can't create a repo without any data - #[test] - fn empty_repository() { + #[tokio::test] + async fn empty_repository() { let root_key = key_path(); let key_source = LocalKeySource { path: root_key }; let root_path = root_path(); - let editor = RepositoryEditor::new(root_path).unwrap(); - assert!(editor.sign(&[Box::new(key_source)]).is_err()); + let editor = RepositoryEditor::new(root_path).await.unwrap(); + assert!(editor.sign(&[Box::new(key_source)]).await.is_err()); } // Make sure we can add targets from different sources #[allow(clippy::similar_names)] - #[test] - fn add_targets_from_multiple_sources() { + #[tokio::test] + async fn add_targets_from_multiple_sources() { let targets: Signed = serde_json::from_str(include_str!( "../../tests/data/tuf-reference-impl/metadata/targets.json" )) @@ -68,22 +68,23 @@ mod tests { let target3_path = targets_path().join("file3.txt"); let target2_path = targets_path().join("file2.txt"); // Use file2.txt to create a "new" target - let target4 = Target::from_path(target2_path).unwrap(); + let target4 = Target::from_path(target2_path).await.unwrap(); let root_path = tuf_root_path(); - let mut editor = RepositoryEditor::new(root_path).unwrap(); + let mut editor = RepositoryEditor::new(root_path).await.unwrap(); editor .targets(targets) .unwrap() .add_target(TargetName::new("file4.txt").unwrap(), target4) .unwrap() .add_target_path(target3_path) + .await .unwrap(); } #[allow(clippy::similar_names)] - #[test] - fn clear_targets() { + #[tokio::test] + async fn clear_targets() { let targets: Signed = serde_json::from_str(include_str!( "../../tests/data/tuf-reference-impl/metadata/targets.json" )) @@ -91,19 +92,20 @@ mod tests { let target3 = targets_path().join("file3.txt"); let root_path = tuf_root_path(); - let mut editor = RepositoryEditor::new(root_path).unwrap(); + let mut editor = RepositoryEditor::new(root_path).await.unwrap(); editor .targets(targets) .unwrap() .add_target_path(target3) + .await .unwrap() .clear_targets() .unwrap(); } // Create and fully sign a repo - #[test] - fn complete_repository() { + #[tokio::test] + async fn complete_repository() { let root = root_path(); let root_key = key_path(); let key_source = LocalKeySource { path: root_key }; @@ -118,7 +120,7 @@ mod tests { let target3 = targets_path().join("file3.txt"); let target_list = vec![target1, target2, target3]; - let mut editor = RepositoryEditor::new(root).unwrap(); + let mut editor = RepositoryEditor::new(root).await.unwrap(); editor .targets_expires(targets_expiration) .unwrap() @@ -129,13 +131,14 @@ mod tests { .timestamp_expires(timestamp_expiration) .timestamp_version(timestamp_version) .add_target_paths(target_list) + .await .unwrap(); - assert!(editor.sign(&[Box::new(key_source)]).is_ok()); + assert!(editor.sign(&[Box::new(key_source)]).await.is_ok()); } // Make sure we can add existing role structs and the proper data is kept. - #[test] - fn existing_roles() { + #[tokio::test] + async fn existing_roles() { let targets: Signed = serde_json::from_str(include_str!( "../../tests/data/tuf-reference-impl/metadata/targets.json" )) @@ -150,7 +153,7 @@ mod tests { .unwrap(); let root_path = tuf_root_path(); - let mut editor = RepositoryEditor::new(root_path).unwrap(); + let mut editor = RepositoryEditor::new(root_path).await.unwrap(); editor .targets(targets) .unwrap() diff --git a/tough/src/error.rs b/tough/src/error.rs index 345747b9..a97f5e14 100644 --- a/tough/src/error.rs +++ b/tough/src/error.rs @@ -71,9 +71,6 @@ pub enum Error { backtrace: Backtrace, }, - #[snafu(display("Failure to obtain a lock in the system_time function: {}", message))] - DatastoreTimeLock { message: String }, - #[snafu(display("Failed to create directory '{}': {}", path.display(), source))] DirCreate { path: PathBuf, @@ -649,10 +646,3 @@ pub enum Error { #[snafu(display("The targets editor was not cleared"))] TargetsEditorSome, } - -// used in `std::io::Read` implementations -impl From for std::io::Error { - fn from(err: Error) -> Self { - Self::new(std::io::ErrorKind::Other, err) - } -} diff --git a/tough/src/fetch.rs b/tough/src/fetch.rs index 989737a0..cdebf736 100644 --- a/tough/src/fetch.rs +++ b/tough/src/fetch.rs @@ -2,43 +2,33 @@ // SPDX-License-Identifier: MIT OR Apache-2.0 use crate::error::{self, Result}; -use crate::io::{DigestAdapter, MaxSizeAdapter}; -use crate::transport::Transport; +use crate::io::{max_size_adapter, DigestAdapter}; +use crate::transport::{Transport, TransportStream}; use snafu::ResultExt; -use std::io::Read; use url::Url; -pub(crate) fn fetch_max_size<'a>( - transport: &'a dyn Transport, +pub(crate) async fn fetch_max_size( + transport: &dyn Transport, url: Url, max_size: u64, specifier: &'static str, -) -> Result { - Ok(MaxSizeAdapter::new( - transport - .fetch(url.clone()) - .context(error::TransportSnafu { url })?, - specifier, - max_size, - )) +) -> Result { + let stream = transport + .fetch(url.clone()) + .await + .with_context(|_| error::TransportSnafu { url: url.clone() })?; + + let stream = max_size_adapter(stream, url, max_size, specifier); + Ok(stream) } -pub(crate) fn fetch_sha256<'a>( - transport: &'a dyn Transport, +pub(crate) async fn fetch_sha256( + transport: &dyn Transport, url: Url, size: u64, specifier: &'static str, sha256: &[u8], -) -> Result { - Ok(DigestAdapter::sha256( - Box::new(MaxSizeAdapter::new( - transport - .fetch(url.clone()) - .context(error::TransportSnafu { url: url.clone() })?, - specifier, - size, - )), - sha256, - url, - )) +) -> Result { + let stream = fetch_max_size(transport, url.clone(), size, specifier).await?; + Ok(DigestAdapter::sha256(stream, sha256, url)) } diff --git a/tough/src/http.rs b/tough/src/http.rs index b7e9eb80..ac348d39 100644 --- a/tough/src/http.rs +++ b/tough/src/http.rs @@ -1,14 +1,21 @@ //! The `http` module provides `HttpTransport` which enables `Repository` objects to be //! loaded over HTTP +use crate::transport::TransportStream; use crate::{Transport, TransportError, TransportErrorKind}; -use log::{debug, error, trace}; -use reqwest::blocking::{Client, ClientBuilder, Request, Response}; +use async_trait::async_trait; +use futures::{FutureExt, StreamExt}; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_core::Stream; +use log::trace; use reqwest::header::{self, HeaderValue, ACCEPT_RANGES}; +use reqwest::{Client, ClientBuilder, Request, Response}; use reqwest::{Error, Method}; use snafu::ResultExt; use snafu::Snafu; use std::cmp::Ordering; -use std::io::Read; +use std::pin::Pin; +use std::task::Poll; use std::time::Duration; use url::Url; @@ -125,93 +132,224 @@ pub struct HttpTransport { } /// Implement the `tough` `Transport` trait for `HttpRetryTransport` +#[async_trait] impl Transport for HttpTransport { - /// Send a GET request to the URL. Request will be retried per the `ClientSettings`. The - /// returned `RetryRead` will also retry as necessary per the `ClientSettings`. - fn fetch(&self, url: Url) -> Result, TransportError> { - let mut r = RetryState::new(self.settings.initial_backoff); - Ok(Box::new( - fetch_with_retries(&mut r, &self.settings, &url) - .map_err(|e| TransportError::from((url, e)))?, - )) + /// Send a GET request to the URL. The returned `TransportStream` will retry as necessary per + /// the `ClientSettings`. + async fn fetch(&self, url: Url) -> Result { + let r = RetryState::new(self.settings.initial_backoff); + Ok(fetch_with_retries(r, &self.settings, &url).boxed()) + } +} + +enum RequestState { + /// A response is streaming. + Streaming(BoxStream<'static, reqwest::Result>), + /// A request is pending. + Pending(BoxFuture<'static, reqwest::Result>), + /// No ongoing request. + None, +} + +impl std::fmt::Debug for RequestState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RequestState::Streaming(_) => f.write_str("Streaming"), + RequestState::Pending(_) => f.write_str("Executing"), + RequestState::None => f.write_str("None"), + } } } -/// This serves as a `Read`, but carries with it the necessary information to do retries. #[derive(Debug)] -pub struct RetryRead { +pub(crate) struct RetryStream { retry_state: RetryState, settings: HttpTransportBuilder, - response: Response, url: Url, + request: RequestState, + done: bool, + has_range_support: bool, } -impl Read for RetryRead { - /// Read bytes into `buf`, retrying as necessary. - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - // retry loop - loop { - let retry_err = match self.response.read(buf) { - Ok(sz) => { - self.retry_state.next_byte += sz; - return Ok(sz); - } - // store the error in `retry_err` to return later if there are no more retries - Err(err) => err, - }; - debug!("error during read of '{}': {:?}", self.url, retry_err); - - // increment the `retry_state` and fetch a new reader if retries are not exhausted - if self.retry_state.current_try >= self.settings.tries - 1 { - // we are out of retries, so return the last known error. - return Err(retry_err); - } - self.retry_state.increment(&self.settings); - self.err_if_no_range_support(retry_err)?; - // wait, then retry the request (with a range header). - std::thread::sleep(self.retry_state.wait); - let new_retry_read = - fetch_with_retries(&mut self.retry_state, &self.settings, &self.url) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - // the new fetch succeeded so we need to replace our read object with the new one. - self.response = new_retry_read.response; +impl Stream for RetryStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.done { + return Poll::Ready(None); } + + self.poll_streaming(cx) + .or_else(|| self.poll_executing(cx)) + .unwrap_or_else(|| match self.poll_new_request(cx) { + Ok(poll) => poll, + Err(e) => Poll::Ready(Some(Err((self.url.clone(), e).into()))), + }) } } -impl RetryRead { - /// Checks for the header `Accept-Ranges: bytes` - fn supports_range(&self) -> bool { - if let Some(ranges) = self.response.headers().get(ACCEPT_RANGES) { - if let Ok(val) = ranges.to_str() { - if val.contains("bytes") { - return true; +impl RetryStream { + pub fn poll_err(&mut self, error: E) -> Poll>> + where + E: Into>, + { + self.done = true; + Poll::Ready(Some(Err(TransportError::new_with_cause( + TransportErrorKind::Other, + self.url.clone(), + error, + )))) + } + + fn poll_streaming( + self: &mut Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Option::Item>>> { + let RequestState::Streaming(stream) = &mut self.request else { + return None; + }; + let next = stream.as_mut().poll_next(cx); + match next { + // Success. End stream. + Poll::Ready(None) => { + self.done = true; + Poll::Ready(None) + } + // New chunk received, keep track of position for potential recovery. + Poll::Ready(Some(Ok(data))) => { + self.retry_state.next_byte += data.len(); + Poll::Ready(Some(Ok(data))) + } + // Error while streaming the response body. Try to recover. + Poll::Ready(Some(Err(err))) => match ErrorClass::from(err) { + ErrorClass::Fatal(e) => self.poll_err(e), + ErrorClass::FileNotFound(_) => unreachable!("streaming the response body already"), + ErrorClass::Retryable(e) => { + if self.may_retry() { + match self.poll_new_request(cx) { + Ok(poll) => poll, + Err(_) => self.poll_err(e), + } + } else { + self.poll_err(e) + } + } + }, + Poll::Pending => Poll::Pending, + } + .into() + } + + fn poll_executing( + self: &mut Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Option::Item>>> { + let RequestState::Pending(request) = &mut self.request else { + return None; + }; + match request.as_mut().poll(cx) { + Poll::Ready(response) => { + let http_result: HttpResult = response.into(); + match http_result { + HttpResult::Ok(response) => { + trace!("{:?} - returning from successful fetch", self.retry_state); + if let Some(ranges) = response.headers().get(ACCEPT_RANGES) { + if let Ok(val) = ranges.to_str() { + if val.contains("bytes") { + self.has_range_support = true; + } + } + } + self.request = RequestState::Streaming(response.bytes_stream().boxed()); + cx.waker().wake_by_ref(); + Poll::Pending + } + HttpResult::Err(ErrorClass::Fatal(e)) => { + trace!( + "{:?} - returning fatal error from fetch: {}", + self.retry_state, + e + ); + self.poll_err(e) + } + HttpResult::Err(ErrorClass::FileNotFound(e)) => { + trace!( + "{:?} - returning file not found from fetch: {}", + self.retry_state, + e + ); + self.done = true; + Poll::Ready(Some(Err(TransportError::new_with_cause( + TransportErrorKind::FileNotFound, + self.url.clone(), + e, + )))) + } + HttpResult::Err(ErrorClass::Retryable(e)) => { + trace!("{:?} - retryable error: {}", self.retry_state, e); + if self.may_retry() { + match self.poll_new_request(cx) { + Ok(poll) => poll, + Err(_) => self.poll_err(e), + } + } else { + self.poll_err(e) + } + } } } + Poll::Pending => Poll::Pending, } - false + .into() } + /// Check all criteria for a retry and account for it. + fn may_retry(&mut self) -> bool { + let tries_left = self + .settings + .tries + .saturating_sub(self.retry_state.current_try); - /// Returns an error when we have received an error during read, but our server does not support - /// range headers. Our retry implementation considers this a fatal condition rather that trying - /// to start over from the beginning and advancing the `Read` to the point where failure - /// occurred. - fn err_if_no_range_support(&self, e: std::io::Error) -> std::io::Result<()> { - if !self.supports_range() { - // we cannot send a byte range request to this server, so return the error - error!( - "an error occurred and we cannot retry because the server \ - does not support range requests '{}': {:?}", - self.url, e - ); - return Err(e); + self.retry_state.increment(&self.settings); + + tries_left > 0 && (self.has_range_support || self.retry_state.next_byte == 0) + } + + /// Move to `RequestState::Executing`. + fn poll_new_request( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Result>>, HttpError> { + // create a reqwest client + let client = ClientBuilder::new() + .timeout(self.settings.timeout) + .connect_timeout(self.settings.connect_timeout) + .build() + .context(HttpClientSnafu)?; + + // build the request + let request = build_request(&client, self.retry_state.next_byte, &self.url)?; + + let backoff = self.retry_state.wait; + + let delayed_request = async move { + tokio::time::sleep(backoff).await; + client.execute(request).await } - Ok(()) + .boxed(); + + self.request = RequestState::Pending(delayed_request); + + // start polling the new request + cx.waker().wake_by_ref(); + Ok(Poll::Pending) } } /// A private struct that serves as the retry counter. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Debug)] struct RetryState { /// The current try we are on. First try is zero. current_try: u32, @@ -251,65 +389,30 @@ impl RetryState { } /// Sends a `GET` request to the `url`. Retries the request as necessary per the `ClientSettings`. -fn fetch_with_retries( - r: &mut RetryState, - cs: &HttpTransportBuilder, - url: &Url, -) -> Result { +fn fetch_with_retries(r: RetryState, cs: &HttpTransportBuilder, url: &Url) -> RetryStream { trace!("beginning fetch for '{}'", url); - // create a reqwest client - let client = ClientBuilder::new() - .timeout(cs.timeout) - .connect_timeout(cs.connect_timeout) - .build() - .context(HttpClientSnafu)?; - - // retry loop - loop { - // build the request - let request = build_request(&client, r.next_byte, url)?; - - // send the GET request, then categories the outcome by converting to an HttpResult. - let http_result: HttpResult = client.execute(request).into(); - - match http_result { - HttpResult::Ok(response) => { - trace!("{:?} - returning from successful fetch", r); - return Ok(RetryRead { - retry_state: *r, - settings: *cs, - response, - url: url.clone(), - }); - } - HttpResult::Fatal(err) => { - trace!("{:?} - returning fatal error from fetch: {}", r, err); - return Err(err).context(FetchFatalSnafu); - } - HttpResult::FileNotFound(err) => { - trace!("{:?} - returning file not found from fetch: {}", r, err); - return Err(err).context(FetchFileNotFoundSnafu); - } - HttpResult::Retryable(err) => { - trace!("{:?} - retryable error: {}", r, err); - if r.current_try >= cs.tries - 1 { - debug!("{:?} - returning failure, no more retries: {}", r, err); - return Err(err).context(FetchNoMoreRetriesSnafu { tries: cs.tries }); - } - } - } - r.increment(cs); - std::thread::sleep(r.wait); + RetryStream { + retry_state: r, + settings: *cs, + url: url.clone(), + request: RequestState::None, + done: false, + has_range_support: false, } } +/// A newtype result for ergonomic conversions. +enum HttpResult { + Ok(reqwest::Response), + Err(ErrorClass), +} + +/// Group reqwest errors into interesting cases. /// Much of the complexity in the `fetch_with_retries` function is in deciphering the `Result` /// we get from `reqwest::Client::execute`. Using this enum we categorize the states of the /// `Result` into the categories that we need to understand. -enum HttpResult { - /// We got a response with an HTTP code that indicates success. - Ok(reqwest::blocking::Response), +enum ErrorClass { /// We got an `Error` (other than file-not-found) which we will not retry. Fatal(reqwest::Error), /// The file could not be found (HTTP status 403 or 404). @@ -320,7 +423,7 @@ enum HttpResult { /// Takes the `Result` type from `reqwest::Client::execute`, and categorizes it into an /// `HttpResult` variant. -impl From> for HttpResult { +impl From> for HttpResult { fn from(result: Result) -> Self { match result { Ok(response) => { @@ -328,28 +431,33 @@ impl From> for HttpResult { // checks the status code of the response for errors parse_response_code(response) } - Err(e) if e.is_timeout() => { - // a connection timeout occurred - trace!("timeout error during fetch: {}", e); - HttpResult::Retryable(e) - } - Err(e) if e.is_request() => { - // an error occurred while sending the request - trace!("error sending request during fetch: {}", e); - HttpResult::Retryable(e) - } - Err(e) => { - // the error is not from an HTTP status code or a timeout, retries will not succeed. - // these appear to be internal, reqwest errors and are expected to be unlikely. - trace!("internal reqwest error during fetch: {}", e); - HttpResult::Fatal(e) - } + Err(e) => Self::Err(e.into()), + } + } +} + +/// Catergorize a `request::Error` into a `HttpResult` variant. +impl From for ErrorClass { + fn from(err: reqwest::Error) -> Self { + if err.is_timeout() { + // a connection timeout occurred + trace!("timeout error during fetch: {}", err); + ErrorClass::Retryable(err) + } else if err.is_request() { + // an error occurred while sending the request + trace!("error sending request during fetch: {}", err); + ErrorClass::Retryable(err) + } else { + // the error is not from an HTTP status code or a timeout, retries will not succeed. + // these appear to be internal, reqwest errors and are expected to be unlikely. + trace!("internal reqwest error during fetch: {}", err); + ErrorClass::Fatal(err) } } } /// Checks the HTTP response code and converts a non-successful response code to an error. -fn parse_response_code(response: reqwest::blocking::Response) -> HttpResult { +fn parse_response_code(response: reqwest::Response) -> HttpResult { match response.error_for_status() { Ok(ok) => { trace!("response is success"); @@ -362,19 +470,19 @@ fn parse_response_code(response: reqwest::blocking::Response) -> HttpResult { // this shouldn't happen, we received this err from the err_for_status function, // so the error should have a status. we cannot consider this a retryable error. trace!("error is fatal (no status): {}", err); - HttpResult::Fatal(err) + HttpResult::Err(ErrorClass::Fatal(err)) } Some(status) if status.is_server_error() => { trace!("error is retryable: {}", err); - HttpResult::Retryable(err) + HttpResult::Err(ErrorClass::Retryable(err)) } Some(status) if matches!(status.as_u16(), 403 | 404 | 410) => { trace!("error is file not found: {}", err); - HttpResult::FileNotFound(err) + HttpResult::Err(ErrorClass::FileNotFound(err)) } Some(_) => { trace!("error is fatal (status): {}", err); - HttpResult::Fatal(err) + HttpResult::Err(ErrorClass::Fatal(err)) } }, } diff --git a/tough/src/io.rs b/tough/src/io.rs index 97c2afe8..f34be54b 100644 --- a/tough/src/io.rs +++ b/tough/src/io.rs @@ -1,133 +1,160 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT OR Apache-2.0 -use crate::error; +use crate::{error, transport::TransportStream, TransportError}; +use futures::StreamExt; +use futures_core::Stream; use ring::digest::{Context, SHA256}; -use std::io::{self, Read}; +use std::{convert::TryInto, path::Path, task::Poll}; +use tokio::fs; use url::Url; -pub(crate) struct DigestAdapter<'a> { +pub(crate) struct DigestAdapter { url: Url, - reader: Box, + stream: TransportStream, hash: Vec, - digest: Option, + digest: Context, } -impl<'a> DigestAdapter<'a> { - pub(crate) fn sha256(reader: Box, hash: &[u8], url: Url) -> Self { +impl DigestAdapter { + pub(crate) fn sha256(stream: TransportStream, hash: &[u8], url: Url) -> TransportStream { Self { url, - reader, + stream, hash: hash.to_owned(), - digest: Some(Context::new(&SHA256)), + digest: Context::new(&SHA256), } + .boxed() } } -impl<'a> Read for DigestAdapter<'a> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - assert!( - self.digest.is_some(), - "DigestAdapter::read called after end of file" - ); +impl Stream for DigestAdapter { + type Item = ::Item; - let size = self.reader.read(buf)?; - if size == 0 { - let result = self.digest.take().unwrap().finish(); - if result.as_ref() != self.hash.as_slice() { - error::HashMismatchSnafu { - context: self.url.to_string(), - calculated: hex::encode(result), - expected: hex::encode(&self.hash), + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let poll = self.stream.as_mut().poll_next(cx); + match &poll { + Poll::Ready(Some(Ok(bytes))) => { + self.digest.update(bytes); + } + Poll::Ready(None) => { + let result = &self.digest.clone().finish(); + if result.as_ref() != self.hash.as_slice() { + let mismatch_err = error::HashMismatchSnafu { + context: self.url.to_string(), + calculated: hex::encode(result), + expected: hex::encode(&self.hash), + } + .build(); + return Poll::Ready(Some(Err(TransportError::new_with_cause( + crate::TransportErrorKind::Other, + self.url.clone(), + mismatch_err, + )))); } - .fail()?; } - Ok(size) - } else if let Some(digest) = &mut self.digest { - digest.update(&buf[..size]); - Ok(size) - } else { - unreachable!(); - } + Poll::Ready(Some(Err(_))) | Poll::Pending => (), + }; + + poll } } -pub(crate) struct MaxSizeAdapter<'a> { - reader: Box, - /// How the `max_size` was specified. For example the max size of `root.json` is specified by - /// the `max_root_size` argument in `Settings`. `specifier` is used to construct an error - /// message when the `MaxSizeAdapter` detects that too many bytes have been read. - specifier: &'static str, +/// Create a new stream from `stream`. The new stream returns an error for the item that exceeds the +/// total byte count of `max_size`. +/// * `stream` - The original stream. +/// * `max_size` - Size limit in bytes. +/// * `specifier` - Error message to use. +pub(crate) fn max_size_adapter( + stream: TransportStream, + url: Url, max_size: u64, - counter: u64, + specifier: &'static str, +) -> TransportStream { + let mut size: u64 = 0; + let stream = stream.map(move |chunk| { + if let Ok(bytes) = &chunk { + size = size.saturating_add(bytes.len().try_into().unwrap_or(u64::MAX)); + } + if size > max_size { + let size_err = error::MaxSizeExceededSnafu { + max_size, + specifier, + } + .build(); + return Err(TransportError::new_with_cause( + crate::TransportErrorKind::Other, + url.clone(), + size_err, + )); + } + chunk + }); + + stream.boxed() } -impl<'a> MaxSizeAdapter<'a> { - pub(crate) fn new( - reader: Box, - specifier: &'static str, - max_size: u64, - ) -> Self { - Self { - reader, - specifier, - max_size, - counter: 0, - } - } +/// Async analogue of `std::path::Path::is_file` +pub async fn is_file(path: impl AsRef) -> bool { + fs::metadata(path) + .await + .map(|m| m.is_file()) + .unwrap_or(false) } -impl<'a> Read for MaxSizeAdapter<'a> { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let size = self.reader.read(buf)?; - self.counter += size as u64; - if self.counter > self.max_size { - error::MaxSizeExceededSnafu { - max_size: self.max_size, - specifier: self.specifier, - } - .fail()?; - } - Ok(size) - } +/// Async analogue of `std::path::Path::is_dir` +pub async fn is_dir(path: impl AsRef) -> bool { + fs::metadata(path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) } #[cfg(test)] mod tests { - use crate::io::{DigestAdapter, MaxSizeAdapter}; + use crate::{ + io::{max_size_adapter, DigestAdapter}, + transport::IntoVec, + }; + use bytes::Bytes; + use futures::{stream, StreamExt}; use hex_literal::hex; - use std::io::{Cursor, Read}; use url::Url; - #[test] - fn test_max_size_adapter() { - let mut reader = MaxSizeAdapter::new(Box::new(Cursor::new(b"hello".to_vec())), "test", 5); - let mut buf = Vec::new(); - assert!(reader.read_to_end(&mut buf).is_ok()); + #[tokio::test] + async fn test_max_size_adapter() { + let url = Url::parse("file:///").unwrap(); + + let stream = stream::iter("hello".as_bytes().chunks(2).map(Bytes::from).map(Ok)).boxed(); + let stream = max_size_adapter(stream, url.clone(), 5, "test"); + let buf = stream.into_vec().await.expect("consuming entire stream"); assert_eq!(buf, b"hello"); - let mut reader = MaxSizeAdapter::new(Box::new(Cursor::new(b"hello".to_vec())), "test", 4); - let mut buf = Vec::new(); - assert!(reader.read_to_end(&mut buf).is_err()); + let stream = stream::iter("hello".as_bytes().chunks(2).map(Bytes::from).map(Ok)).boxed(); + let stream = max_size_adapter(stream, url, 4, "test"); + assert!(stream.into_vec().await.is_err()); } - #[test] - fn test_digest_adapter() { - let mut reader = DigestAdapter::sha256( - Box::new(Cursor::new(b"hello".to_vec())), + #[tokio::test] + async fn test_digest_adapter() { + let stream = stream::iter("hello".as_bytes().chunks(2).map(Bytes::from).map(Ok)).boxed(); + let stream = DigestAdapter::sha256( + stream, &hex!("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"), Url::parse("file:///").unwrap(), ); - let mut buf = Vec::new(); - assert!(reader.read_to_end(&mut buf).is_ok()); + let buf = stream.into_vec().await.expect("consuming entire stream"); assert_eq!(buf, b"hello"); - let mut reader = DigestAdapter::sha256( - Box::new(Cursor::new(b"hello".to_vec())), + let stream = stream::iter("hello".as_bytes().chunks(2).map(Bytes::from).map(Ok)).boxed(); + let stream = DigestAdapter::sha256( + stream, &hex!("0ebdc3317b75839f643387d783535adc360ca01f33c75f7c1e7373adcd675c0b"), Url::parse("file:///").unwrap(), ); - let mut buf = Vec::new(); - assert!(reader.read_to_end(&mut buf).is_err()); + assert!(stream.into_vec().await.is_err()); } } diff --git a/tough/src/key_source.rs b/tough/src/key_source.rs index d8c2d319..b556f5fb 100644 --- a/tough/src/key_source.rs +++ b/tough/src/key_source.rs @@ -5,6 +5,7 @@ //! obtained, for example, from local files or from cloud provider key stores. use crate::error; use crate::sign::{parse_keypair, Sign}; +use async_trait::async_trait; use snafu::ResultExt; use std::fmt::Debug; use std::path::PathBuf; @@ -12,12 +13,15 @@ use std::result::Result; /// This trait should be implemented for each source of signing keys. Examples /// of sources include: files, AWS SSM, etc. +#[async_trait] pub trait KeySource: Debug + Send + Sync { /// Returns an object that implements the `Sign` trait - fn as_sign(&self) -> Result, Box>; + async fn as_sign( + &self, + ) -> Result, Box>; /// Writes a key back to the `KeySource` - fn write( + async fn write( &self, value: &str, key_id_hex: &str, @@ -32,18 +36,24 @@ pub struct LocalKeySource { } /// Implements the `KeySource` trait for a `LocalKeySource` (file) +#[async_trait] impl KeySource for LocalKeySource { - fn as_sign(&self) -> Result, Box> { - let data = std::fs::read(&self.path).context(error::FileReadSnafu { path: &self.path })?; + async fn as_sign( + &self, + ) -> Result, Box> { + let data = tokio::fs::read(&self.path) + .await + .context(error::FileReadSnafu { path: &self.path })?; Ok(Box::new(parse_keypair(&data)?)) } - fn write( + async fn write( &self, value: &str, _key_id_hex: &str, ) -> Result<(), Box> { - Ok(std::fs::write(&self.path, value.as_bytes()) + Ok(tokio::fs::write(&self.path, value.as_bytes()) + .await .context(error::FileWriteSnafu { path: &self.path })?) } } diff --git a/tough/src/lib.rs b/tough/src/lib.rs index 38e85d02..57acd5da 100644 --- a/tough/src/lib.rs +++ b/tough/src/lib.rs @@ -50,24 +50,31 @@ use crate::error::Result; use crate::fetch::{fetch_max_size, fetch_sha256}; /// An HTTP transport that includes retries. #[cfg(feature = "http")] -pub use crate::http::{HttpTransport, HttpTransportBuilder, RetryRead}; +pub use crate::http::{HttpTransport, HttpTransportBuilder}; +use crate::io::is_dir; use crate::schema::{ DelegatedRole, Delegations, Role, RoleType, Root, Signed, Snapshot, Timestamp, }; pub use crate::target_name::TargetName; +pub use crate::transport::IntoVec; pub use crate::transport::{ DefaultTransport, FilesystemTransport, Transport, TransportError, TransportErrorKind, }; pub use crate::urlpath::SafeUrlPath; +use async_recursion::async_recursion; +pub use async_trait::async_trait; +pub use bytes::Bytes; use chrono::{DateTime, Utc}; +use futures::StreamExt; +use futures_core::Stream; use log::warn; use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC}; use snafu::{ensure, OptionExt, ResultExt}; use std::collections::HashMap; -use std::fs::create_dir_all; -use std::io::Read; use std::path::{Path, PathBuf}; use tempfile::NamedTempFile; +use tokio::fs::{canonicalize, create_dir_all}; +use tokio::io::AsyncWriteExt; use url::Url; /// Represents whether a Repository should fail to load when metadata is expired (`Safe`) or whether @@ -114,7 +121,6 @@ impl From for bool { /// ## Basic usage: /// /// ```rust -/// # use std::fs::File; /// # use std::path::PathBuf; /// # use tough::RepositoryLoader; /// # use url::Url; @@ -122,21 +128,23 @@ impl From for bool { /// # let root = dir.join("metadata").join("1.root.json"); /// # let metadata_base_url = Url::from_file_path(dir.join("metadata")).unwrap(); /// # let targets_base_url = Url::from_file_path(dir.join("targets")).unwrap(); +/// # tokio_test::block_on(async { /// /// let repository = RepositoryLoader::new( -/// File::open(root).unwrap(), +/// &tokio::fs::read(root).await.unwrap(), /// metadata_base_url, /// targets_base_url, /// ) /// .load() +/// .await /// .unwrap(); /// +/// # }); /// ``` /// /// ## With optional settings: /// /// ```rust -/// # use std::fs::File; /// # use std::path::PathBuf; /// # use tough::{RepositoryLoader, FilesystemTransport, ExpirationEnforcement}; /// # use url::Url; @@ -144,24 +152,24 @@ impl From for bool { /// # let root = dir.join("metadata").join("1.root.json"); /// # let metadata_base_url = Url::from_file_path(dir.join("metadata")).unwrap(); /// # let targets_base_url = Url::from_file_path(dir.join("targets")).unwrap(); +/// # tokio_test::block_on(async { /// /// let repository = RepositoryLoader::new( -/// File::open(root).unwrap(), +/// &tokio::fs::read(root).await.unwrap(), /// metadata_base_url, /// targets_base_url, /// ) /// .transport(FilesystemTransport) /// .expiration_enforcement(ExpirationEnforcement::Unsafe) /// .load() +/// .await /// .unwrap(); /// +/// # }); /// ``` #[derive(Debug, Clone)] -pub struct RepositoryLoader -where - R: Read, -{ - root: R, +pub struct RepositoryLoader<'a> { + root: &'a [u8], metadata_base_url: Url, targets_base_url: Url, transport: Option>, @@ -170,19 +178,19 @@ where expiration_enforcement: Option, } -impl RepositoryLoader { +impl<'a> RepositoryLoader<'a> { /// Create a new `RepositoryLoader`. /// - /// `root` is a [`Read`]er for the trusted root metadata file, which you must ship with your + /// `root` is the content of a trusted root metadata file, which you must ship with your /// software using an out-of-band process. It should be a copy of the most recent root.json /// from your repository. (It's okay if it becomes out of date later; the client establishes /// trust up to the most recent root.json file.) /// /// `metadata_base_url` and `targets_base_url` are the base URLs where the client can find /// metadata (such as root.json) and targets (as listed in targets.json). - pub fn new(root: R, metadata_base_url: Url, targets_base_url: Url) -> Self { + pub fn new(root: &'a impl AsRef<[u8]>, metadata_base_url: Url, targets_base_url: Url) -> Self { Self { - root, + root: root.as_ref(), metadata_base_url, targets_base_url, transport: None, @@ -193,8 +201,8 @@ impl RepositoryLoader { } /// Load and verify TUF repository metadata. - pub fn load(self) -> Result { - Repository::load(self) + pub async fn load(self) -> Result { + Repository::load(self).await } /// Set the transport. If no transport has been set, [`DefaultTransport`] will be used. @@ -314,7 +322,7 @@ pub struct Repository { impl Repository { /// Load and verify TUF repository metadata using a [`RepositoryLoader`] for the settings. - fn load(loader: RepositoryLoader) -> Result { + async fn load(loader: RepositoryLoader<'_>) -> Result { let datastore = Datastore::new(loader.datastore)?; let transport = loader .transport @@ -333,7 +341,8 @@ impl Repository { limits.max_root_updates, &metadata_base_url, expiration_enforcement, - )?; + ) + .await?; // 2. Download the timestamp metadata file let timestamp = load_timestamp( @@ -343,7 +352,8 @@ impl Repository { limits.max_timestamp_size, &metadata_base_url, expiration_enforcement, - )?; + ) + .await?; // 3. Download the snapshot metadata file let snapshot = load_snapshot( @@ -353,7 +363,8 @@ impl Repository { &datastore, &metadata_base_url, expiration_enforcement, - )?; + ) + .await?; // 4. Download the targets metadata file let targets = load_targets( @@ -364,7 +375,8 @@ impl Repository { limits.max_targets_size, &metadata_base_url, expiration_enforcement, - )?; + ) + .await?; let expires_iter = [ (root.signed.expires, RoleType::Root), @@ -424,15 +436,19 @@ impl Repository { /// /// If the requested target is not listed in the repository metadata, `Ok(None)` is returned. /// - /// Otherwise, a reader is returned, which provides streaming access to the target contents - /// before its checksum is validated. If the maximum size is reached or there is a checksum - /// mismatch, the reader returns a [`std::io::Error`]. **Consumers of this library must not use - /// data from the reader if it returns an error.** - pub fn read_target(&self, name: &TargetName) -> Result> { + /// Otherwise, a stream is returned, which provides access to the target contents before its + /// checksum is validated. If the maximum size is reached or there is a checksum mismatch, the + /// stream returns a [`error::Error`]. **Consumers of this library must not use data from the + /// stream if it returns an error.** + pub async fn read_target( + &self, + name: &TargetName, + ) -> Result> + IntoVec + Send>> + { // Check for repository metadata expiration. if self.expiration_enforcement == ExpirationEnforcement::Safe { ensure!( - self.datastore.system_time()? < self.earliest_expiration, + self.datastore.system_time().await? < self.earliest_expiration, error::ExpiredMetadataSnafu { role: self.earliest_expiration_role } @@ -458,7 +474,7 @@ impl Repository { // non-volatile storage as FILENAME.EXT. Ok(if let Ok(target) = self.targets.signed.find_target(name) { let (sha256, file) = self.target_digest_and_filename(target, name); - Some(self.fetch_target(target, &sha256, file.as_str())?) + Some(self.fetch_target(target, &sha256, file.as_str()).await?) } else { None }) @@ -481,17 +497,17 @@ impl Repository { /// - Will error if the result of path resolution results in a filepath outside of `outdir` or /// outside of a delegated target's correct path of delegation. /// - pub fn save_target

(&self, name: &TargetName, outdir: P, prepend: Prefix) -> Result<()> + pub async fn save_target

(&self, name: &TargetName, outdir: P, prepend: Prefix) -> Result<()> where P: AsRef, { // Ensure the outdir exists then canonicalize the path. let outdir = outdir.as_ref(); - let outdir = outdir - .canonicalize() + let outdir = canonicalize(outdir) + .await .context(error::SaveTargetOutdirCanonicalizeSnafu { path: outdir })?; ensure!( - outdir.is_dir(), + is_dir(&outdir).await, error::SaveTargetOutdirSnafu { path: outdir } ); @@ -541,17 +557,37 @@ impl Repository { ); // Fetch and write the target using NamedTempFile for an atomic file creation. - let mut reader = self - .read_target(name)? + let mut stream = self + .read_target(name) + .await? .with_context(|| error::SaveTargetNotFoundSnafu { name: name.clone() })?; - create_dir_all(filepath_dir).context(error::DirCreateSnafu { - path: &filepath_dir, - })?; - let mut f = - NamedTempFile::new_in(filepath_dir).context(error::NamedTempFileCreateSnafu { + create_dir_all(filepath_dir) + .await + .context(error::DirCreateSnafu { path: &filepath_dir, })?; - std::io::copy(&mut reader, &mut f).context(error::FileWriteSnafu { path: &f.path() })?; + + // Create a new temporary file. + let tmp_path = filepath_dir.to_owned(); + let tmp = tokio::task::spawn_blocking(move || NamedTempFile::new_in(tmp_path)) + .await + // We do not cancel the task nor do we expect it to panic + .unwrap_or_else(|_| unreachable!()) + .context(error::NamedTempFileCreateSnafu { path: filepath_dir })?; + + // Convert to `tokio::fs::File`. + let (f, tmp_path) = tmp.into_parts(); + let mut f = tokio::fs::File::from_std(f); + + // Write input stream to file. + while let Some(bytes) = stream.next().await { + f.write_all(bytes?.as_ref()) + .await + .context(error::FileWriteSnafu { path: &tmp_path })?; + } + + // Reconstruct `NamedTempFile` in order to persist it at the target location. + let f = NamedTempFile::from_parts(f.into_std().await, tmp_path); f.persist(&resolved_filepath) .context(error::NamedTempFilePersistSnafu { path: resolved_filepath, @@ -591,9 +627,9 @@ pub(crate) fn encode_filename>(name: S) -> String { /// TUF v1.0.16, 5.2.9, 5.3.3, 5.4.5, 5.5.4, The expiration timestamp in the `[metadata]` file MUST /// be higher than the fixed update start time. -fn check_expired(datastore: &Datastore, role: &T) -> Result<()> { +async fn check_expired(datastore: &Datastore, role: &T) -> Result<()> { ensure!( - datastore.system_time()? <= role.expires(), + datastore.system_time().await? <= role.expires(), error::ExpiredMetadataSnafu { role: T::TYPE } ); Ok(()) @@ -614,7 +650,7 @@ fn parse_url(url: Url) -> Result { /// Steps 0 and 1 of the client application, which load the current root metadata file based on a /// trusted root metadata file. -fn load_root( +async fn load_root>( transport: &dyn Transport, root: R, datastore: &Datastore, @@ -628,7 +664,7 @@ fn load_root( // that the expiration of the trusted root metadata file does not matter, because we will // attempt to update it in the next step. let mut root: Signed = - serde_json::from_reader(root).context(error::ParseTrustedMetadataSnafu)?; + serde_json::from_slice(root.as_ref()).context(error::ParseTrustedMetadataSnafu)?; root.signed .verify_role(&root) .context(error::VerifyTrustedMetadataSnafu)?; @@ -669,19 +705,29 @@ fn load_root( error::MaxUpdatesExceededSnafu { max_root_updates } ); let path = format!("{}.root.json", root.signed.version.get() + 1); + let url = metadata_base_url + .join(&path) + .with_context(|_| error::JoinUrlSnafu { + path: path.clone(), + url: metadata_base_url.clone(), + })?; match fetch_max_size( transport, - metadata_base_url.join(&path).context(error::JoinUrlSnafu { - path, - url: metadata_base_url.clone(), - })?, + url.clone(), max_root_size, "max_root_size argument", - ) { + ) + .await + { Err(_) => break, // If this file is not available, then go to step 1.8. - Ok(reader) => { + Ok(stream) => { + let data = match stream.into_vec().await { + Ok(d) => d, + Err(e) if e.kind() == TransportErrorKind::FileNotFound => break, + err @ Err(_) => err.context(error::TransportSnafu { url })?, + }; let new_root: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Root, })?; @@ -749,7 +795,7 @@ fn load_root( // file has expired, abort the update cycle, report the potential freeze attack. On the next // update cycle, begin at step 5.1 and version N of the root metadata file. if expiration_enforcement == ExpirationEnforcement::Safe { - check_expired(datastore, &root.signed)?; + check_expired(datastore, &root.signed).await?; } // 1.9. If the timestamp and / or snapshot keys have been rotated, then delete the trusted @@ -765,8 +811,8 @@ fn load_root( .iter() .ne(root.signed.keys(RoleType::Snapshot)) { - let r1 = datastore.remove("timestamp.json"); - let r2 = datastore.remove("snapshot.json"); + let r1 = datastore.remove("timestamp.json").await; + let r2 = datastore.remove("snapshot.json").await; r1.and(r2)?; } @@ -780,7 +826,7 @@ fn load_root( } /// Step 2 of the client application, which loads the timestamp metadata file. -fn load_timestamp( +async fn load_timestamp( transport: &dyn Transport, root: &Signed, datastore: &Datastore, @@ -793,17 +839,25 @@ fn load_timestamp( // example, Y may be tens of kilobytes. The filename used to download the timestamp metadata // file is of the fixed form FILENAME.EXT (e.g., timestamp.json). let path = "timestamp.json"; - let reader = fetch_max_size( - transport, - metadata_base_url.join(path).context(error::JoinUrlSnafu { + let url = metadata_base_url + .join(path) + .with_context(|_| error::JoinUrlSnafu { path, url: metadata_base_url.clone(), - })?, + })?; + let stream = fetch_max_size( + transport, + url.clone(), max_timestamp_size, "max_timestamp_size argument", - )?; + ) + .await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url })?; let timestamp: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Timestamp, })?; @@ -821,8 +875,9 @@ fn load_timestamp( // file. If the new timestamp metadata file is older than the trusted timestamp metadata // file, discard it, abort the update cycle, and report the potential rollback attack. if let Some(Ok(old_timestamp)) = datastore - .reader("timestamp.json")? - .map(serde_json::from_reader::<_, Signed>) + .bytes("timestamp.json") + .await? + .map(|b| serde_json::from_slice::>(&b)) { if root.signed.verify_role(&old_timestamp).is_ok() { ensure!( @@ -841,17 +896,17 @@ fn load_timestamp( // metadata file becomes the trusted timestamp metadata file. If the new timestamp metadata file // has expired, discard it, abort the update cycle, and report the potential freeze attack. if expiration_enforcement == ExpirationEnforcement::Safe { - check_expired(datastore, ×tamp.signed)?; + check_expired(datastore, ×tamp.signed).await?; } // Now that everything seems okay, write the timestamp file to the datastore. - datastore.create("timestamp.json", ×tamp)?; + datastore.create("timestamp.json", ×tamp).await?; Ok(timestamp) } /// Step 3 of the client application, which loads the snapshot metadata file. -fn load_snapshot( +async fn load_snapshot( transport: &dyn Transport, root: &Signed, timestamp: &Signed, @@ -880,18 +935,26 @@ fn load_snapshot( } else { "snapshot.json".to_owned() }; - let reader = fetch_sha256( - transport, - metadata_base_url.join(&path).context(error::JoinUrlSnafu { - path, + let url = metadata_base_url + .join(&path) + .with_context(|_| error::JoinUrlSnafu { + path: path.clone(), url: metadata_base_url.clone(), - })?, + })?; + let stream = fetch_sha256( + transport, + url.clone(), snapshot_meta.length, "timestamp.json", &snapshot_meta.hashes.sha256, - )?; + ) + .await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url })?; let snapshot: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Snapshot, })?; @@ -925,8 +988,9 @@ fn load_snapshot( // 3.3.1. Note that the trusted snapshot metadata file may be checked for authenticity, but its // expiration does not matter for the following purposes. if let Some(Ok(old_snapshot)) = datastore - .reader("snapshot.json")? - .map(serde_json::from_reader::<_, Signed>) + .bytes("snapshot.json") + .await? + .map(|b| serde_json::from_slice::>(&b)) { // 3.3.2. The version number of the trusted snapshot metadata file, if any, MUST be less // than or equal to the version number of the new snapshot metadata file. If the new @@ -976,17 +1040,17 @@ fn load_snapshot( // metadata file becomes the trusted snapshot metadata file. If the new snapshot metadata file // is expired, discard it, abort the update cycle, and report the potential freeze attack. if expiration_enforcement == ExpirationEnforcement::Safe { - check_expired(datastore, &snapshot.signed)?; + check_expired(datastore, &snapshot.signed).await?; } // Now that everything seems okay, write the snapshot file to the datastore. - datastore.create("snapshot.json", &snapshot)?; + datastore.create("snapshot.json", &snapshot).await?; Ok(snapshot) } /// Step 4 of the client application, which loads the targets metadata file. -fn load_targets( +async fn load_targets( transport: &dyn Transport, root: &Signed, snapshot: &Signed, @@ -1018,32 +1082,34 @@ fn load_targets( } else { "targets.json".to_owned() }; - let targets_url = metadata_base_url.join(&path).context(error::JoinUrlSnafu { - path, - url: metadata_base_url.clone(), - })?; + let targets_url = metadata_base_url + .join(&path) + .with_context(|_| error::JoinUrlSnafu { + path, + url: metadata_base_url.clone(), + })?; let (max_targets_size, specifier) = match targets_meta.length { Some(length) => (length, "snapshot.json"), None => (max_targets_size, "max_targets_size parameter"), }; - let reader = if let Some(hashes) = &targets_meta.hashes { - Box::new(fetch_sha256( + let stream = if let Some(hashes) = &targets_meta.hashes { + fetch_sha256( transport, - targets_url, + targets_url.clone(), max_targets_size, specifier, &hashes.sha256, - )?) as Box + ) + .await? } else { - Box::new(fetch_max_size( - transport, - targets_url, - max_targets_size, - specifier, - )?) + fetch_max_size(transport, targets_url.clone(), max_targets_size, specifier).await? }; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url: targets_url })?; let mut targets: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Targets, })?; @@ -1077,8 +1143,9 @@ fn load_targets( // If the new targets metadata file is older than the trusted targets metadata file, discard // it, abort the update cycle, and report the potential rollback attack. if let Some(Ok(old_targets)) = datastore - .reader("targets.json")? - .map(serde_json::from_reader::<_, Signed>) + .bytes("targets.json") + .await? + .map(|b| serde_json::from_slice::>(&b)) { if root.signed.verify_role(&old_targets).is_ok() { ensure!( @@ -1097,11 +1164,11 @@ fn load_targets( // metadata file becomes the trusted targets metadata file. If the new targets metadata file is // expired, discard it, abort the update cycle, and report the potential freeze attack. if expiration_enforcement == ExpirationEnforcement::Safe { - check_expired(datastore, &targets.signed)?; + check_expired(datastore, &targets.signed).await?; } // Now that everything seems okay, write the targets file to the datastore. - datastore.create("targets.json", &targets)?; + datastore.create("targets.json", &targets).await?; // 4.5. Perform a preorder depth-first search for metadata about the desired target, beginning // with the top-level targets role. @@ -1114,7 +1181,8 @@ fn load_targets( max_targets_size, delegations, datastore, - )?; + ) + .await?; } // This validation can only be done from the top level targets.json role. This check verifies @@ -1124,7 +1192,8 @@ fn load_targets( } // Follow the paths of delegations starting with the top level targets.json delegation -fn load_delegations( +#[async_recursion] +async fn load_delegations( transport: &dyn Transport, snapshot: &Signed, consistent_snapshot: bool, @@ -1141,7 +1210,7 @@ fn load_delegations( .signed .meta .get(&format!("{}.json", &delegated_role.name)) - .context(error::RoleNotInMetaSnafu { + .with_context(|| error::RoleNotInMetaSnafu { name: delegated_role.name.clone(), })?; @@ -1154,21 +1223,23 @@ fn load_delegations( } else { format!("{}.json", encode_filename(&delegated_role.name)) }; - let role_url = metadata_base_url.join(&path).context(error::JoinUrlSnafu { - path: path.clone(), - url: metadata_base_url.clone(), - })?; + let role_url = metadata_base_url + .join(&path) + .with_context(|_| error::JoinUrlSnafu { + path: path.clone(), + url: metadata_base_url.clone(), + })?; let specifier = "max_targets_size parameter"; // load the role json file - let reader = Box::new(fetch_max_size( - transport, - role_url, - max_targets_size, - specifier, - )?); + let stream = + fetch_max_size(transport, role_url.clone(), max_targets_size, specifier).await?; + let data = stream + .into_vec() + .await + .context(error::TransportSnafu { url: role_url })?; // since each role is a targets, we load them as such let role: Signed = - serde_json::from_reader(reader).context(error::ParseMetadataSnafu { + serde_json::from_slice(&data).context(error::ParseMetadataSnafu { role: RoleType::Targets, })?; // verify each role with the delegation @@ -1186,16 +1257,17 @@ fn load_delegations( } ); - datastore.create(&path, &role)?; + datastore.create(&path, &role).await?; delegated_roles.insert(delegated_role.name.clone(), Some(role)); } // load all roles delegated by this role for delegated_role in &mut delegation.roles { - delegated_role.targets = delegated_roles.remove(&delegated_role.name).context( - error::DelegatedRolesNotConsistentSnafu { - name: delegated_role.name.clone(), - }, - )?; + delegated_role.targets = + delegated_roles + .remove(&delegated_role.name) + .with_context(|| error::DelegatedRolesNotConsistentSnafu { + name: delegated_role.name.clone(), + })?; if let Some(targets) = &mut delegated_role.targets { if let Some(delegations) = &mut targets.signed.delegations { load_delegations( @@ -1206,7 +1278,8 @@ fn load_delegations( max_targets_size, delegations, datastore, - )?; + ) + .await?; } } } diff --git a/tough/src/schema/mod.rs b/tough/src/schema/mod.rs index 13005df7..f3e2846c 100644 --- a/tough/src/schema/mod.rs +++ b/tough/src/schema/mod.rs @@ -28,12 +28,12 @@ use serde_json::Value; use serde_plain::{derive_display_from_serialize, derive_fromstr_from_deserialize}; use snafu::ResultExt; use std::collections::HashMap; -use std::fs::File; -use std::io::Read; use std::num::NonZeroU64; use std::ops::{Deref, DerefMut}; use std::path::Path; use std::str::FromStr; +use tokio::fs::File; +use tokio::io::AsyncReadExt; /// The type of metadata role. #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash)] @@ -455,7 +455,7 @@ pub struct Target { impl Target { /// Given a path, returns a Target struct - pub fn from_path

(path: P) -> Result + pub async fn from_path

(path: P) -> Result where P: AsRef, { @@ -466,12 +466,18 @@ impl Target { } // Get the sha256 and length of the target - let mut file = File::open(path).context(error::FileOpenSnafu { path })?; + let mut file = File::open(path) + .await + .context(error::FileOpenSnafu { path })?; let mut digest = Context::new(&SHA256); let mut buf = [0; 8 * 1024]; let mut length = 0; loop { - match file.read(&mut buf).context(error::FileReadSnafu { path })? { + match file + .read(&mut buf) + .await + .context(error::FileReadSnafu { path })? + { 0 => break, n => { digest.update(&buf[..n]); diff --git a/tough/src/sign.rs b/tough/src/sign.rs index 7b889040..6c00585c 100644 --- a/tough/src/sign.rs +++ b/tough/src/sign.rs @@ -8,6 +8,7 @@ use crate::schema::key::Key; use crate::sign::SignKeyPair::ECDSA; use crate::sign::SignKeyPair::ED25519; use crate::sign::SignKeyPair::RSA; +use async_trait::async_trait; use ring::rand::SecureRandom; use ring::signature::{EcdsaKeyPair, Ed25519KeyPair, KeyPair, RsaKeyPair}; use snafu::ResultExt; @@ -16,34 +17,37 @@ use std::error::Error; /// This trait must be implemented for each type of key with which you will /// sign things. +#[async_trait] pub trait Sign: Sync + Send { /// Returns the decoded key along with its scheme and other metadata fn tuf_key(&self) -> crate::schema::key::Key; /// Signs the supplied message - fn sign( + async fn sign( &self, msg: &[u8], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> std::result::Result, Box>; } /// Implements `Sign` for a reference to any type that implements `Sign`. +#[async_trait] impl<'a, T: Sign> Sign for &'a T { fn tuf_key(&self) -> Key { (*self).tuf_key() } - fn sign( + async fn sign( &self, msg: &[u8], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> std::prelude::rust_2015::Result, Box> { - (*self).sign(msg, rng) + (*self).sign(msg, rng).await } } /// Implements the Sign trait for ED25519 +#[async_trait] impl Sign for Ed25519KeyPair { fn tuf_key(&self) -> Key { use crate::schema::key::{Ed25519Key, Ed25519Scheme}; @@ -58,10 +62,10 @@ impl Sign for Ed25519KeyPair { } } - fn sign( + async fn sign( &self, msg: &[u8], - _rng: &dyn SecureRandom, + _rng: &(dyn SecureRandom + Sync), ) -> std::result::Result, Box> { let signature = self.sign(msg); Ok(signature.as_ref().to_vec()) @@ -69,6 +73,7 @@ impl Sign for Ed25519KeyPair { } /// Implements the Sign trait for RSA keypairs +#[async_trait] impl Sign for RsaKeyPair { fn tuf_key(&self) -> Key { use crate::schema::key::{RsaKey, RsaScheme}; @@ -83,10 +88,10 @@ impl Sign for RsaKeyPair { } } - fn sign( + async fn sign( &self, msg: &[u8], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> std::result::Result, Box> { let mut signature = vec![0; self.public_modulus_len()]; self.sign(&ring::signature::RSA_PSS_SHA256, rng, msg, &mut signature) @@ -96,6 +101,7 @@ impl Sign for RsaKeyPair { } /// Implements the Sign trait for ECDSA keypairs +#[async_trait] impl Sign for EcdsaKeyPair { fn tuf_key(&self) -> Key { use crate::schema::key::{EcdsaKey, EcdsaScheme}; @@ -110,10 +116,10 @@ impl Sign for EcdsaKeyPair { } } - fn sign( + async fn sign( &self, msg: &[u8], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> std::result::Result, Box> { let signature = self.sign(rng, msg).context(error::SignSnafu)?; Ok(signature.as_ref().to_vec()) @@ -132,6 +138,7 @@ pub enum SignKeyPair { ECDSA(EcdsaKeyPair), } +#[async_trait] impl Sign for SignKeyPair { fn tuf_key(&self) -> Key { match self { @@ -141,15 +148,15 @@ impl Sign for SignKeyPair { } } - fn sign( + async fn sign( &self, msg: &[u8], - rng: &dyn SecureRandom, + rng: &(dyn SecureRandom + Sync), ) -> std::result::Result, Box> { match self { - RSA(key) => (key as &dyn Sign).sign(msg, rng), - ED25519(key) => (key as &dyn Sign).sign(msg, rng), - ECDSA(key) => (key as &dyn Sign).sign(msg, rng), + RSA(key) => (key as &dyn Sign).sign(msg, rng).await, + ED25519(key) => (key as &dyn Sign).sign(msg, rng).await, + ECDSA(key) => (key as &dyn Sign).sign(msg, rng).await, } } } diff --git a/tough/src/transport.rs b/tough/src/transport.rs index b75f1f59..f9c4f680 100644 --- a/tough/src/transport.rs +++ b/tough/src/transport.rs @@ -1,12 +1,39 @@ use crate::SafeUrlPath; #[cfg(feature = "http")] use crate::{HttpTransport, HttpTransportBuilder}; +use async_trait::async_trait; +use bytes::Bytes; use dyn_clone::DynClone; +use futures::{StreamExt, TryStreamExt}; +use futures_core::Stream; use std::error::Error; use std::fmt::{Debug, Display, Formatter}; -use std::io::{ErrorKind, Read}; +use std::io::{self, ErrorKind}; +use std::path::Path; +use std::pin::Pin; +use tokio_util::io::ReaderStream; use url::Url; +pub type TransportStream = Pin> + Send>>; + +/// Fallible byte streams that collect into a `Vec`. +#[async_trait] +pub trait IntoVec { + /// Try to collect into `Vec`. + async fn into_vec(self) -> Result, E>; +} + +#[async_trait] +impl> + Send, E: Send> IntoVec for S { + async fn into_vec(self) -> Result, E> { + self.try_fold(Vec::new(), |mut acc, bytes| { + acc.extend(bytes.as_ref()); + std::future::ready(Ok(acc)) + }) + .await + } +} + /// A trait to abstract over the method/protocol by which files are obtained. /// /// The trait hides the underlying types involved by returning the `Read` object as a @@ -14,9 +41,10 @@ use url::Url; /// /// Inclusion of the `DynClone` trait means that you will need to implement `Clone` when /// implementing a `Transport`. -pub trait Transport: Debug + DynClone { +#[async_trait] +pub trait Transport: Debug + DynClone + Send + Sync { /// Opens a `Read` object for the file specified by `url`. - fn fetch(&self, url: Url) -> Result, TransportError>; + async fn fetch(&self, url: Url) -> Result; } // Implements `Clone` for `Transport` trait objects (i.e. on `Box::`). To facilitate @@ -26,7 +54,7 @@ dyn_clone::clone_trait_object!(Transport); // =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= /// The kind of error that the transport object experienced during `fetch`. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[non_exhaustive] pub enum TransportErrorKind { /// The [`Transport`] does not handle the URL scheme. e.g. `file://` or `http://`. @@ -139,8 +167,24 @@ impl Error for TransportError { #[derive(Debug, Clone, Copy)] pub struct FilesystemTransport; +impl FilesystemTransport { + async fn open( + file_path: impl AsRef, + ) -> Result> + Send, io::Error> { + // Open the file + let f = tokio::fs::File::open(file_path).await?; + + // And convert to stream + let reader = tokio::io::BufReader::new(f); + let stream = ReaderStream::new(reader); + + Ok(stream) + } +} + +#[async_trait] impl Transport for FilesystemTransport { - fn fetch(&self, url: Url) -> Result, TransportError> { + async fn fetch(&self, url: Url) -> Result { // If the scheme isn't "file://", reject if url.scheme() != "file" { return Err(TransportError::new( @@ -151,15 +195,21 @@ impl Transport for FilesystemTransport { let file_path = url.safe_url_filepath(); - // And open the file - let f = std::fs::File::open(file_path).map_err(|e| { + // Open the file + let stream = Self::open(file_path).await; + + // And map to `TransportError` + let map_io_err = move |e: io::Error| -> TransportError { let kind = match e.kind() { ErrorKind::NotFound => TransportErrorKind::FileNotFound, _ => TransportErrorKind::Other, }; - TransportError::new_with_cause(kind, url, e) - })?; - Ok(Box::new(f)) + TransportError::new_with_cause(kind, url.clone(), e) + }; + Ok(stream + .map_err(map_io_err.clone())? + .map_err(map_io_err) + .boxed()) } } @@ -202,11 +252,12 @@ impl DefaultTransport { } } +#[async_trait] impl Transport for DefaultTransport { - fn fetch(&self, url: Url) -> Result, TransportError> { + async fn fetch(&self, url: Url) -> Result { match url.scheme() { - "file" => self.file.fetch(url), - "http" | "https" => self.handle_http(url), + "file" => self.file.fetch(url).await, + "http" | "https" => self.handle_http(url).await, _ => Err(TransportError::new( TransportErrorKind::UnsupportedUrlScheme, url, @@ -218,7 +269,7 @@ impl Transport for DefaultTransport { impl DefaultTransport { #[cfg(not(feature = "http"))] #[allow(clippy::trivially_copy_pass_by_ref, clippy::unused_self)] - fn handle_http(&self, url: Url) -> Result, TransportError> { + async fn handle_http(&self, url: Url) -> Result { Err(TransportError::new_with_cause( TransportErrorKind::UnsupportedUrlScheme, url, @@ -227,7 +278,7 @@ impl DefaultTransport { } #[cfg(feature = "http")] - fn handle_http(&self, url: Url) -> Result, TransportError> { - self.http.fetch(url) + async fn handle_http(&self, url: Url) -> Result { + self.http.fetch(url).await } } diff --git a/tough/tests/expiration_enforcement.rs b/tough/tests/expiration_enforcement.rs index 52c36091..d2e57d44 100644 --- a/tough/tests/expiration_enforcement.rs +++ b/tough/tests/expiration_enforcement.rs @@ -1,7 +1,6 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::fs::File; use test_utils::{dir_url, test_data}; use tough::error::Error::ExpiredMetadata; use tough::schema::RoleType; @@ -11,16 +10,19 @@ mod test_utils; /// Test that `tough` fails to load an expired repository when `expiration_enforcement` is `Safe`. /// -#[test] -fn test_expiration_enforcement_safe() { +#[tokio::test] +async fn test_expiration_enforcement_safe() { let base = test_data().join("expired-repository"); let result = RepositoryLoader::new( - File::open(base.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("metadata").join("1.root.json")) + .await + .unwrap(), dir_url(base.join("metadata")), dir_url(base.join("targets")), ) - .load(); + .load() + .await; if let Err(err) = result { match err { ExpiredMetadata { role, backtrace: _ } => { @@ -39,15 +41,18 @@ fn test_expiration_enforcement_safe() { /// Test that `tough` loads an expired repository when `expiration_enforcement` is `Unsafe`. /// -#[test] -fn test_expiration_enforcement_unsafe() { +#[tokio::test] +async fn test_expiration_enforcement_unsafe() { let base = test_data().join("expired-repository"); let result = RepositoryLoader::new( - File::open(base.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("metadata").join("1.root.json")) + .await + .unwrap(), dir_url(base.join("metadata")), dir_url(base.join("targets")), ) .expiration_enforcement(ExpirationEnforcement::Unsafe) - .load(); + .load() + .await; assert!(result.is_ok()) } diff --git a/tough/tests/http.rs b/tough/tests/http.rs index 1f83d055..3b2de5c3 100644 --- a/tough/tests/http.rs +++ b/tough/tests/http.rs @@ -5,15 +5,14 @@ mod test_utils; mod http_happy { use crate::test_utils::{read_to_end, test_data}; use httptest::{matchers::*, responders::*, Expectation, Server}; - use std::fs::File; use std::str::FromStr; use tough::{DefaultTransport, HttpTransport, RepositoryLoader, TargetName, Transport}; use url::Url; /// Set an expectation in a test HTTP server which serves a file from `tuf-reference-impl`. - fn create_successful_get(relative_path: &str) -> httptest::Expectation { + async fn create_successful_get(relative_path: &str) -> httptest::Expectation { let repo_dir = test_data().join("tuf-reference-impl"); - let file_bytes = std::fs::read(repo_dir.join(relative_path)).unwrap(); + let file_bytes = tokio::fs::read(repo_dir.join(relative_path)).await.unwrap(); Expectation::matching(request::method_path("GET", format!("/{}", relative_path))) .times(1) .respond_with( @@ -34,47 +33,50 @@ mod http_happy { } /// Test that `tough` works with a healthy HTTP server. - #[test] - fn test_http_transport_happy_case() { - run_http_test(HttpTransport::default()); + #[tokio::test] + async fn test_http_transport_happy_case() { + run_http_test(HttpTransport::default()).await; } /// Test that `DefaultTransport` works over HTTP when the `http` feature is enabled. - #[test] - fn test_http_default_transport() { - run_http_test(DefaultTransport::default()); + #[tokio::test] + async fn test_http_default_transport() { + run_http_test(DefaultTransport::default()).await; } - fn run_http_test(transport: T) { + async fn run_http_test(transport: T) { let server = Server::run(); let repo_dir = test_data().join("tuf-reference-impl"); - server.expect(create_successful_get("metadata/timestamp.json")); - server.expect(create_successful_get("metadata/snapshot.json")); - server.expect(create_successful_get("metadata/targets.json")); - server.expect(create_successful_get("metadata/role1.json")); - server.expect(create_successful_get("metadata/role2.json")); - server.expect(create_successful_get("targets/file1.txt")); - server.expect(create_successful_get("targets/file2.txt")); + server.expect(create_successful_get("metadata/timestamp.json").await); + server.expect(create_successful_get("metadata/snapshot.json").await); + server.expect(create_successful_get("metadata/targets.json").await); + server.expect(create_successful_get("metadata/role1.json").await); + server.expect(create_successful_get("metadata/role2.json").await); + server.expect(create_successful_get("targets/file1.txt").await); + server.expect(create_successful_get("targets/file2.txt").await); server.expect(create_unsuccessful_get("metadata/2.root.json")); let metadata_base_url = Url::from_str(server.url_str("/metadata").as_str()).unwrap(); let targets_base_url = Url::from_str(server.url_str("/targets").as_str()).unwrap(); let repo = RepositoryLoader::new( - File::open(repo_dir.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(repo_dir.join("metadata").join("1.root.json")) + .await + .unwrap(), metadata_base_url, targets_base_url, ) .transport(transport) .load() + .await .unwrap(); let file1 = TargetName::new("file1.txt").unwrap(); assert_eq!( - read_to_end(repo.read_target(&file1).unwrap().unwrap()), + read_to_end(repo.read_target(&file1).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); let file2 = TargetName::new("file2.txt").unwrap(); assert_eq!( - read_to_end(repo.read_target(&file2).unwrap().unwrap()), + read_to_end(repo.read_target(&file2).await.unwrap().unwrap()).await, &b"This is an another example target file."[..] ); assert_eq!( @@ -96,7 +98,6 @@ mod http_happy { mod http_integ { use crate::test_utils::test_data; use failure_server::IntegServers; - use std::fs::File; use std::path::PathBuf; use tough::{HttpTransportBuilder, RepositoryLoader}; use url::Url; @@ -131,34 +132,29 @@ mod http_integ { .expect("Failed to run integration test HTTP servers"); // Load the tuf-reference-impl repo via http repeatedly through faulty proxies. - // We avoid nested tokio runtimes from `reqwest::blocking` by sequestering it to another - // thread in a blocking task. - tokio::task::spawn_blocking(move || { - for i in 0..5 { - let transport = HttpTransportBuilder::new() - // the service we have created is very toxic with many failures, so we will do a - // large number of retries, enough that we can be reasonably assured that we - // will always succeed. - .tries(200) - // we don't want the test to take forever so we use small pauses - .initial_backoff(std::time::Duration::from_nanos(100)) - .max_backoff(std::time::Duration::from_millis(1)) - .build(); - let root_path = tuf_reference_impl_root_json(); + for i in 0..5 { + let transport = HttpTransportBuilder::new() + // the service we have created is very toxic with many failures, so we will do a + // large number of retries, enough that we can be reasonably assured that we + // will always succeed. + .tries(200) + // we don't want the test to take forever so we use small pauses + .initial_backoff(std::time::Duration::from_nanos(100)) + .max_backoff(std::time::Duration::from_millis(1)) + .build(); + let root_path = tuf_reference_impl_root_json(); - RepositoryLoader::new( - File::open(&root_path).unwrap(), - Url::parse("http://localhost:10102/metadata").unwrap(), - Url::parse("http://localhost:10102/targets").unwrap(), - ) - .transport(transport) - .load() - .unwrap(); - println!("{}:{} SUCCESSFULLY LOADED THE REPO {}", file!(), line!(), i,); - } - }) - .await - .expect("Failed to load the repo through faulty proxies"); + RepositoryLoader::new( + &tokio::fs::read(&root_path).await.unwrap(), + Url::parse("http://localhost:10102/metadata").unwrap(), + Url::parse("http://localhost:10102/targets").unwrap(), + ) + .transport(transport) + .load() + .await + .unwrap(); + println!("{}:{} SUCCESSFULLY LOADED THE REPO {}", file!(), line!(), i,); + } integ_servers .teardown() diff --git a/tough/tests/interop.rs b/tough/tests/interop.rs index bd6e0e8f..a60b5a35 100644 --- a/tough/tests/interop.rs +++ b/tough/tests/interop.rs @@ -1,7 +1,6 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::fs::File; use tempfile::TempDir; use test_utils::{dir_url, read_to_end, test_data}; use tough::{FilesystemTransport, Limits, Repository, RepositoryLoader, TargetName}; @@ -12,30 +11,33 @@ mod test_utils; /// implementation using the `load_default` function. /// /// [`tuf`]: https://github.com/theupdateframework/tuf -#[test] -fn test_tuf_reference_impl() { +#[tokio::test] +async fn test_tuf_reference_impl() { let base = test_data().join("tuf-reference-impl"); let repo = RepositoryLoader::new( - File::open(base.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("metadata").join("1.root.json")) + .await + .unwrap(), dir_url(base.join("metadata")), dir_url(base.join("targets")), ) .load() + .await .unwrap(); - assert_tuf_reference_impl(&repo); + assert_tuf_reference_impl(&repo).await; } -fn assert_tuf_reference_impl(repo: &Repository) { +async fn assert_tuf_reference_impl(repo: &Repository) { let file1 = TargetName::new("file1.txt").unwrap(); let file2 = TargetName::new("file2.txt").unwrap(); let file3 = TargetName::new("file3.txt").unwrap(); assert_eq!( - read_to_end(repo.read_target(&file1).unwrap().unwrap()), + read_to_end(repo.read_target(&file1).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); assert_eq!( - read_to_end(repo.read_target(&file2).unwrap().unwrap()), + read_to_end(repo.read_target(&file2).await.unwrap().unwrap()).await, &b"This is an another example target file."[..] ); assert_eq!( @@ -61,13 +63,15 @@ fn assert_tuf_reference_impl(repo: &Repository) { /// Test that `tough` can process repositories generated by [`tuf`], the reference Python /// implementation using the `load` function with non-default [`Options`]. -#[test] -fn test_tuf_reference_impl_default_transport() { +#[tokio::test] +async fn test_tuf_reference_impl_default_transport() { let base = test_data().join("tuf-reference-impl"); let datastore = TempDir::new().unwrap(); let repo = RepositoryLoader::new( - File::open(base.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("metadata").join("1.root.json")) + .await + .unwrap(), dir_url(base.join("metadata")), dir_url(base.join("targets")), ) @@ -80,25 +84,29 @@ fn test_tuf_reference_impl_default_transport() { }) .datastore(datastore.path()) .load() + .await .unwrap(); - assert_tuf_reference_impl(&repo); + assert_tuf_reference_impl(&repo).await; } /// Test that `tough` can load a repository that has some unusual delegate role names. This ensures /// that percent encoded role names are handled correctly and that path traversal characters in a /// role name do not cause `tough` to write outside of its datastore. -#[test] -fn test_dubious_role_name() { +#[tokio::test] +async fn test_dubious_role_name() { let base = test_data().join("dubious-role-names"); let datastore = TempDir::new().unwrap(); let repo = RepositoryLoader::new( - File::open(base.join("metadata").join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("metadata").join("1.root.json")) + .await + .unwrap(), dir_url(base.join("metadata")), dir_url(base.join("targets")), ) .datastore(datastore.path()) .load() + .await .unwrap(); // Prove that the role name has path traversal characters. diff --git a/tough/tests/repo_cache.rs b/tough/tests/repo_cache.rs index 7d0cae2d..034ec509 100644 --- a/tough/tests/repo_cache.rs +++ b/tough/tests/repo_cache.rs @@ -1,8 +1,6 @@ // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT OR Apache-2.0 -use std::fs::File; -use std::io::Read; use std::path::PathBuf; use tempfile::TempDir; use test_utils::{dir_url, read_to_end, test_data, DATA_1, DATA_2}; @@ -27,27 +25,28 @@ impl RepoPaths { } } - fn root(&self) -> File { - File::open(&self.root_path).unwrap() + async fn root(&self) -> Vec { + tokio::fs::read(&self.root_path).await.unwrap() } } -fn load_tuf_reference_impl(paths: &RepoPaths) -> Repository { +async fn load_tuf_reference_impl(paths: &RepoPaths) -> Repository { RepositoryLoader::new( - &mut paths.root(), + &paths.root().await, paths.metadata_base_url.clone(), paths.targets_base_url.clone(), ) .load() + .await .unwrap() } /// Test that the repo.cache() function works when given a list of multiple targets. -#[test] -fn test_repo_cache_all_targets() { +#[tokio::test] +async fn test_repo_cache_all_targets() { // load the reference_impl repo let repo_paths = RepoPaths::new(); - let repo = load_tuf_reference_impl(&repo_paths); + let repo = load_tuf_reference_impl(&repo_paths).await; // cache the repo for future use let destination = TempDir::new().unwrap(); @@ -59,45 +58,35 @@ fn test_repo_cache_all_targets() { None::<&[&str]>, true, ) + .await .unwrap(); // check that we can load the copied repo. let copied_repo = RepositoryLoader::new( - repo_paths.root(), + &repo_paths.root().await, dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); // the copied repo should have file1 and file2 (i.e. all of targets). - let mut file_data = Vec::new(); let file1 = TargetName::new("file1.txt").unwrap(); - let file_size = copied_repo - .read_target(&file1) - .unwrap() - .unwrap() - .read_to_end(&mut file_data) - .unwrap(); - assert_eq!(31, file_size); + let file_data = read_to_end(copied_repo.read_target(&file1).await.unwrap().unwrap()).await; + assert_eq!(31, file_data.len()); - let mut file_data = Vec::new(); let file2 = TargetName::new("file2.txt").unwrap(); - let file_size = copied_repo - .read_target(&file2) - .unwrap() - .unwrap() - .read_to_end(&mut file_data) - .unwrap(); - assert_eq!(39, file_size); + let file_data = read_to_end(copied_repo.read_target(&file2).await.unwrap().unwrap()).await; + assert_eq!(39, file_data.len()); } /// Test that the repo.cache() function works when given a list of multiple targets. -#[test] -fn test_repo_cache_list_of_two_targets() { +#[tokio::test] +async fn test_repo_cache_list_of_two_targets() { // load the reference_impl repo let repo_paths = RepoPaths::new(); - let repo = load_tuf_reference_impl(&repo_paths); + let repo = load_tuf_reference_impl(&repo_paths).await; // cache the repo for future use let destination = TempDir::new().unwrap(); @@ -110,45 +99,35 @@ fn test_repo_cache_list_of_two_targets() { Some(&targets_subset), true, ) + .await .unwrap(); // check that we can load the copied repo. let copied_repo = RepositoryLoader::new( - repo_paths.root(), + &repo_paths.root().await, dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); // the copied repo should have file1 and file2 (i.e. all of the listed targets). - let mut file_data = Vec::new(); let file1 = TargetName::new("file1.txt").unwrap(); - let file_size = copied_repo - .read_target(&file1) - .unwrap() - .unwrap() - .read_to_end(&mut file_data) - .unwrap(); - assert_eq!(31, file_size); + let file_data = read_to_end(copied_repo.read_target(&file1).await.unwrap().unwrap()).await; + assert_eq!(31, file_data.len()); - let mut file_data = Vec::new(); let file2 = TargetName::new("file2.txt").unwrap(); - let file_size = copied_repo - .read_target(&file2) - .unwrap() - .unwrap() - .read_to_end(&mut file_data) - .unwrap(); - assert_eq!(39, file_size); + let file_data = read_to_end(copied_repo.read_target(&file2).await.unwrap().unwrap()).await; + assert_eq!(39, file_data.len()); } /// Test that the repo.cache() function works when given a list of only one of the targets. -#[test] -fn test_repo_cache_some() { +#[tokio::test] +async fn test_repo_cache_some() { // load the reference_impl repo let repo_paths = RepoPaths::new(); - let repo = load_tuf_reference_impl(&repo_paths); + let repo = load_tuf_reference_impl(&repo_paths).await; // cache the repo for future use let destination = TempDir::new().unwrap(); @@ -161,58 +140,57 @@ fn test_repo_cache_some() { Some(&targets_subset), true, ) + .await .unwrap(); // check that we can load the copied repo. let copied_repo = RepositoryLoader::new( - repo_paths.root(), + &repo_paths.root().await, dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); // the copied repo should have file2 but not file1 (i.e. only the listed targets). let file1 = TargetName::new("file1.txt").unwrap(); - let read_target_result = copied_repo.read_target(&file1); + let read_target_result = copied_repo.read_target(&file1).await; assert!(read_target_result.is_err()); - let mut file_data = Vec::new(); let file2 = TargetName::new("file2.txt").unwrap(); - let file_size = copied_repo - .read_target(&file2) - .unwrap() - .unwrap() - .read_to_end(&mut file_data) - .unwrap(); - assert_eq!(39, file_size); + let file_data = read_to_end(copied_repo.read_target(&file2).await.unwrap().unwrap()).await; + assert_eq!(39, file_data.len()); } -#[test] -fn test_repo_cache_metadata() { +#[tokio::test] +async fn test_repo_cache_metadata() { // Load the reference_impl repo let repo_paths = RepoPaths::new(); - let repo = load_tuf_reference_impl(&repo_paths); + let repo = load_tuf_reference_impl(&repo_paths).await; // Cache the repo for future use let destination = TempDir::new().unwrap(); let metadata_destination = destination.as_ref().join("metadata"); - repo.cache_metadata(&metadata_destination, true).unwrap(); + repo.cache_metadata(&metadata_destination, true) + .await + .unwrap(); // Load the copied repo - this validates we cached the metadata (if we didn't we couldn't load // the repo) let targets_destination = destination.as_ref().join("targets"); let copied_repo = RepositoryLoader::new( - repo_paths.root(), + &repo_paths.root().await, dir_url(&metadata_destination), dir_url(targets_destination), ) .load() + .await .unwrap(); // Validate we didn't cache any targets for (target_name, _) in copied_repo.targets().signed.targets_map() { - assert!(copied_repo.read_target(&target_name).is_err()) + assert!(copied_repo.read_target(&target_name).await.is_err()) } // Verify we also loaded the delegated role "role1" @@ -223,34 +201,37 @@ fn test_repo_cache_metadata() { assert!(metadata_destination.join("1.root.json").exists()); } -#[test] -fn test_repo_cache_metadata_no_root_chain() { +#[tokio::test] +async fn test_repo_cache_metadata_no_root_chain() { // Load the reference_impl repo let repo_paths = RepoPaths::new(); - let repo = load_tuf_reference_impl(&repo_paths); + let repo = load_tuf_reference_impl(&repo_paths).await; // Cache the repo for future use let destination = TempDir::new().unwrap(); let metadata_destination = destination.as_ref().join("metadata"); - repo.cache_metadata(&metadata_destination, false).unwrap(); + repo.cache_metadata(&metadata_destination, false) + .await + .unwrap(); // Verify we did not cache the root.json assert!(!metadata_destination.join("1.root.json").exists()); } /// Test that the repo.cache() function prepends target names with sha digest. -#[test] -fn test_repo_cache_consistent_snapshots() { +#[tokio::test] +async fn test_repo_cache_consistent_snapshots() { let repo_name = "consistent-snapshots"; let metadata_dir = test_data().join(repo_name).join("metadata"); let targets_dir = test_data().join(repo_name).join("targets"); let root = metadata_dir.join("1.root.json"); let repo = RepositoryLoader::new( - File::open(&root).unwrap(), + &tokio::fs::read(&root).await.unwrap(), dir_url(metadata_dir), dir_url(targets_dir), ) .load() + .await .unwrap(); // cache the repo for future use @@ -264,33 +245,43 @@ fn test_repo_cache_consistent_snapshots() { Option::<&[&str]>::None, true, ) + .await .unwrap(); // check that we can load the copied repo. let copied_repo = RepositoryLoader::new( - File::open(&root).unwrap(), + &tokio::fs::read(&root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); // the copied repo should have file2 but not file1 (i.e. only the listed targets). - let data1 = String::from_utf8(read_to_end( - copied_repo - .read_target(&TargetName::new("data1.txt").unwrap()) - .unwrap() - .unwrap(), - )) + let data1 = String::from_utf8( + read_to_end( + copied_repo + .read_target(&TargetName::new("data1.txt").unwrap()) + .await + .unwrap() + .unwrap(), + ) + .await, + ) .unwrap(); assert_eq!(data1, DATA_1); - let data2 = String::from_utf8(read_to_end( - copied_repo - .read_target(&TargetName::new("data2.txt").unwrap()) - .unwrap() - .unwrap(), - )) + let data2 = String::from_utf8( + read_to_end( + copied_repo + .read_target(&TargetName::new("data2.txt").unwrap()) + .await + .unwrap() + .unwrap(), + ) + .await, + ) .unwrap(); assert_eq!(data2, DATA_2); diff --git a/tough/tests/repo_editor.rs b/tough/tests/repo_editor.rs index 4d6a172a..0eec36e6 100644 --- a/tough/tests/repo_editor.rs +++ b/tough/tests/repo_editor.rs @@ -4,11 +4,11 @@ use crate::test_utils::{dir_url, read_to_end, test_data}; use chrono::{Duration, Utc}; use std::collections::HashMap; -use std::fs::File; -use std::io::prelude::Write; use std::num::NonZeroU64; use std::path::PathBuf; use tempfile::TempDir; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; use tough::editor::signed::PathExists; use tough::editor::{targets::TargetsEditor, RepositoryEditor}; use tough::key_source::KeySource; @@ -38,8 +38,8 @@ impl RepoPaths { } } - fn root(&self) -> File { - File::open(&self.root_path).unwrap() + async fn root(&self) -> Vec { + tokio::fs::read(&self.root_path).await.unwrap() } } @@ -65,17 +65,18 @@ fn targets_path() -> PathBuf { test_data().join("tuf-reference-impl").join("targets") } -fn load_tuf_reference_impl(paths: &mut RepoPaths) -> Repository { +async fn load_tuf_reference_impl(paths: &mut RepoPaths) -> Repository { RepositoryLoader::new( - paths.root(), + &paths.root().await, paths.metadata_base_url.clone(), paths.targets_base_url.clone(), ) .load() + .await .unwrap() } -fn test_repo_editor() -> RepositoryEditor { +async fn test_repo_editor() -> RepositoryEditor { let root = root_path(); let timestamp_expiration = Utc::now().checked_add_signed(Duration::days(3)).unwrap(); let timestamp_version = NonZeroU64::new(1234).unwrap(); @@ -86,7 +87,7 @@ fn test_repo_editor() -> RepositoryEditor { let target3 = targets_path().join("file3.txt"); let target_list = vec![target3]; - let mut editor = RepositoryEditor::new(root).unwrap(); + let mut editor = RepositoryEditor::new(root).await.unwrap(); editor .targets_expires(targets_expiration) .unwrap() @@ -97,33 +98,34 @@ fn test_repo_editor() -> RepositoryEditor { .timestamp_expires(timestamp_expiration) .timestamp_version(timestamp_version) .add_target_paths(target_list) + .await .unwrap(); editor } -fn key_hash_map(keys: &[Box]) -> HashMap, Key> { +async fn key_hash_map(keys: &[Box]) -> HashMap, Key> { let mut key_pairs = HashMap::new(); for source in keys { - let key_pair = source.as_sign().unwrap().tuf_key(); + let key_pair = source.as_sign().await.unwrap().tuf_key(); key_pairs.insert(key_pair.key_id().unwrap().clone(), key_pair.clone()); } key_pairs } // Test a RepositoryEditor can be created from an existing Repo -#[test] -fn repository_editor_from_repository() { +#[tokio::test] +async fn repository_editor_from_repository() { // Load the reference_impl repo let mut repo_paths = RepoPaths::new(); let root = repo_paths.root_path.clone(); - let repo = load_tuf_reference_impl(&mut repo_paths); + let repo = load_tuf_reference_impl(&mut repo_paths).await; - assert!(RepositoryEditor::from_repo(root, repo).is_ok()); + assert!(RepositoryEditor::from_repo(root, repo).await.is_ok()); } // Create sign write and reload repo -#[test] -fn create_sign_write_reload_repo() { +#[tokio::test] +async fn create_sign_write_reload_repo() { let root = root_path(); let timestamp_expiration = Utc::now().checked_add_signed(Duration::days(3)).unwrap(); let timestamp_version = NonZeroU64::new(1234).unwrap(); @@ -136,7 +138,7 @@ fn create_sign_write_reload_repo() { let create_dir = TempDir::new().unwrap(); - let mut editor = RepositoryEditor::new(&root).unwrap(); + let mut editor = RepositoryEditor::new(&root).await.unwrap(); editor .targets_expires(targets_expiration) .unwrap() @@ -147,6 +149,7 @@ fn create_sign_write_reload_repo() { .timestamp_expires(timestamp_expiration) .timestamp_version(timestamp_version) .add_target_paths(target_list) + .await .unwrap(); let targets_key: &[std::boxed::Box<(dyn tough::key_source::KeySource + 'static)>] = @@ -170,14 +173,17 @@ fn create_sign_write_reload_repo() { Utc::now().checked_add_signed(Duration::days(21)).unwrap(), NonZeroU64::new(1).unwrap(), ) + .await .unwrap(); // switch repo owner to role1 editor .sign_targets_editor(targets_key) + .await .unwrap() .change_delegated_targets("role1") .unwrap() .add_target_paths([targets_path().join("file1.txt").to_str().unwrap()].to_vec()) + .await .unwrap() .delegate_role( "role2", @@ -187,6 +193,7 @@ fn create_sign_write_reload_repo() { Utc::now().checked_add_signed(Duration::days(21)).unwrap(), NonZeroU64::new(1).unwrap(), ) + .await .unwrap() .delegate_role( "role3", @@ -196,6 +203,7 @@ fn create_sign_write_reload_repo() { Utc::now().checked_add_signed(Duration::days(21)).unwrap(), NonZeroU64::new(1).unwrap(), ) + .await .unwrap(); editor .targets_version(targets_version) @@ -203,6 +211,7 @@ fn create_sign_write_reload_repo() { .targets_expires(targets_expiration) .unwrap() .sign_targets_editor(role1_key) + .await .unwrap() .change_delegated_targets("targets") .unwrap() @@ -214,35 +223,38 @@ fn create_sign_write_reload_repo() { Utc::now().checked_add_signed(Duration::days(21)).unwrap(), NonZeroU64::new(1).unwrap(), ) + .await .unwrap() .targets_version(targets_version) .unwrap() .targets_expires(targets_expiration) .unwrap(); - let signed_repo = editor.sign(targets_key).unwrap(); + let signed_repo = editor.sign(targets_key).await.unwrap(); let metadata_destination = create_dir.path().join("metadata"); let targets_destination = create_dir.path().join("targets"); - assert!(signed_repo.write(&metadata_destination).is_ok()); + assert!(signed_repo.write(&metadata_destination).await.is_ok()); assert!(signed_repo .link_targets(targets_path(), &targets_destination, PathExists::Skip) + .await .is_ok()); // Load the repo we just created let _new_repo = RepositoryLoader::new( - File::open(&root).unwrap(), + &tokio::fs::read(&root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); } -#[test] +#[tokio::test] /// Delegates role from Targets to A and then A to B -fn create_role_flow() { - let editor = test_repo_editor(); +async fn create_role_flow() { + let editor = test_repo_editor().await; let targets_key: &[std::boxed::Box<(dyn tough::key_source::KeySource + 'static)>] = &[Box::new(LocalKeySource { path: key_path() })]; @@ -259,43 +271,51 @@ fn create_role_flow() { let repodir = TempDir::new().unwrap(); let metadata_destination = repodir.as_ref().join("metadata"); let targets_destination = repodir.as_ref().join("targets"); - let signed = editor.sign(targets_key).unwrap(); - signed.write(&metadata_destination).unwrap(); + let signed = editor.sign(targets_key).await.unwrap(); + signed.write(&metadata_destination).await.unwrap(); // create new delegated target as "A" and sign with role1_key let new_role = TargetsEditor::new("A") .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()) .sign(role1_key) + .await .unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); - new_role.write(&metadata_destination_out, false).unwrap(); + new_role + .write(&metadata_destination_out, false) + .await + .unwrap(); // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); // add outdir to repo //create a new editor with the repo - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); editor .add_role( "A", metadata_base_url_out.as_str(), PathSet::Paths(vec![PathPattern::new("*.txt").unwrap()]), NonZeroU64::new(1).unwrap(), - Some(key_hash_map(role1_key)), + Some(key_hash_map(role1_key).await), ) + .await .unwrap(); //sign everything since targets key is the same as snapshot and timestamp @@ -316,23 +336,24 @@ fn create_role_flow() { .snapshot_version(snapshot_version) .timestamp_expires(timestamp_expiration) .timestamp_version(timestamp_version); - let signed = editor.sign(&[Box::new(key_source)]).unwrap(); + let signed = editor.sign(&[Box::new(key_source)]).await.unwrap(); // write repo let new_dir = TempDir::new().unwrap(); let metadata_destination = new_dir.as_ref().join("metadata"); let targets_destination = new_dir.as_ref().join("targets"); - signed.write(&metadata_destination).unwrap(); + signed.write(&metadata_destination).await.unwrap(); // reload repo and verify that A role is included // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); new_repo.delegated_role("A").unwrap(); @@ -343,20 +364,25 @@ fn create_role_flow() { .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()) .sign(role2_key) + .await .unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); - new_role.write(&metadata_destination_out, false).unwrap(); + new_role + .write(&metadata_destination_out, false) + .await + .unwrap(); // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); @@ -371,31 +397,34 @@ fn create_role_flow() { metadata_base_url_out.as_str(), PathSet::Paths(vec![PathPattern::new("file?.txt").unwrap()]), NonZeroU64::new(1).unwrap(), - Some(key_hash_map(role2_key)), + Some(key_hash_map(role2_key).await), ) + .await .unwrap() .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()); // sign A and write A and B metadata to output directory - let signed_roles = editor.sign(role1_key).unwrap(); + let signed_roles = editor.sign(role1_key).await.unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); signed_roles .write(&metadata_destination_out, false) + .await .unwrap(); // reload repo and add in A and B metadata and update snapshot // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); @@ -403,9 +432,12 @@ fn create_role_flow() { let root_key = key_path(); let key_source = LocalKeySource { path: root_key }; - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); editor .update_delegated_targets("A", metadata_base_url_out.as_str()) + .await .unwrap(); editor .snapshot_version(NonZeroU64::new(1).unwrap()) @@ -413,7 +445,7 @@ fn create_role_flow() { .timestamp_version(NonZeroU64::new(1).unwrap()) .timestamp_expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()); - let signed_refreshed_repo = editor.sign(&[Box::new(key_source)]).unwrap(); + let signed_refreshed_repo = editor.sign(&[Box::new(key_source)]).await.unwrap(); // write repo let end_repo = TempDir::new().unwrap(); @@ -421,16 +453,20 @@ fn create_role_flow() { let metadata_destination = end_repo.as_ref().join("metadata"); let targets_destination = end_repo.as_ref().join("targets"); - signed_refreshed_repo.write(&metadata_destination).unwrap(); + signed_refreshed_repo + .write(&metadata_destination) + .await + .unwrap(); // reload repo and verify that A and B role are included let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(metadata_destination), dir_url(targets_destination), ) .load() + .await .unwrap(); // verify that role A and B are included @@ -438,11 +474,11 @@ fn create_role_flow() { new_repo.delegated_role("B").unwrap(); } -#[test] +#[tokio::test] /// Delegtes role from Targets to A and then A to B -fn update_targets_flow() { +async fn update_targets_flow() { // The beginning of this creates a repo with Target -> A ('*.txt') -> B ('file?.txt') - let editor = test_repo_editor(); + let editor = test_repo_editor().await; let targets_key: &[std::boxed::Box<(dyn tough::key_source::KeySource + 'static)>] = &[Box::new(LocalKeySource { path: key_path() })]; @@ -459,43 +495,51 @@ fn update_targets_flow() { let repodir = TempDir::new().unwrap(); let metadata_destination = repodir.as_ref().join("metadata"); let targets_destination = repodir.as_ref().join("targets"); - let signed = editor.sign(targets_key).unwrap(); - signed.write(&metadata_destination).unwrap(); + let signed = editor.sign(targets_key).await.unwrap(); + signed.write(&metadata_destination).await.unwrap(); // create new delegated target as "A" and sign with role1_key let new_role = TargetsEditor::new("A") .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()) .sign(role1_key) + .await .unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); - new_role.write(&metadata_destination_out, false).unwrap(); + new_role + .write(&metadata_destination_out, false) + .await + .unwrap(); // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); // add outdir to repo //create a new editor with the repo - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); editor .add_role( "A", metadata_base_url_out.as_str(), PathSet::Paths(vec![PathPattern::new("*.txt").unwrap()]), NonZeroU64::new(1).unwrap(), - Some(key_hash_map(role1_key)), + Some(key_hash_map(role1_key).await), ) + .await .unwrap(); //sign everything since targets key is the same as snapshot and timestamp @@ -516,23 +560,24 @@ fn update_targets_flow() { .snapshot_version(snapshot_version) .timestamp_expires(timestamp_expiration) .timestamp_version(timestamp_version); - let signed = editor.sign(&[Box::new(key_source)]).unwrap(); + let signed = editor.sign(&[Box::new(key_source)]).await.unwrap(); // write repo let new_dir = TempDir::new().unwrap(); let metadata_destination = new_dir.as_ref().join("metadata"); let targets_destination = new_dir.as_ref().join("targets"); - signed.write(&metadata_destination).unwrap(); + signed.write(&metadata_destination).await.unwrap(); // reload repo and verify that A role is included // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); new_repo.delegated_role("A").unwrap(); @@ -543,20 +588,25 @@ fn update_targets_flow() { .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()) .sign(role2_key) + .await .unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); - new_role.write(&metadata_destination_out, false).unwrap(); + new_role + .write(&metadata_destination_out, false) + .await + .unwrap(); // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); @@ -571,31 +621,34 @@ fn update_targets_flow() { metadata_base_url_out.as_str(), PathSet::Paths(vec![PathPattern::new("file?.txt").unwrap()]), NonZeroU64::new(1).unwrap(), - Some(key_hash_map(role2_key)), + Some(key_hash_map(role2_key).await), ) + .await .unwrap() .version(NonZeroU64::new(1).unwrap()) .expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()); // sign A and write A and B metadata to output directory - let signed_roles = editor.sign(role1_key).unwrap(); + let signed_roles = editor.sign(role1_key).await.unwrap(); // write the role to outdir let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); signed_roles .write(&metadata_destination_out, false) + .await .unwrap(); // reload repo and add in A and B metadata and update snapshot // reload repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); @@ -603,9 +656,12 @@ fn update_targets_flow() { let root_key = key_path(); let key_source = LocalKeySource { path: root_key }; - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); editor .update_delegated_targets("A", metadata_base_url_out.as_str()) + .await .unwrap(); editor .snapshot_version(NonZeroU64::new(1).unwrap()) @@ -613,7 +669,7 @@ fn update_targets_flow() { .timestamp_version(NonZeroU64::new(1).unwrap()) .timestamp_expires(Utc::now().checked_add_signed(Duration::days(21)).unwrap()); - let signed_refreshed_repo = editor.sign(&[Box::new(key_source)]).unwrap(); + let signed_refreshed_repo = editor.sign(&[Box::new(key_source)]).await.unwrap(); // write repo let end_repo = TempDir::new().unwrap(); @@ -621,16 +677,20 @@ fn update_targets_flow() { let metadata_destination = end_repo.as_ref().join("metadata"); let targets_destination = end_repo.as_ref().join("targets"); - signed_refreshed_repo.write(&metadata_destination).unwrap(); + signed_refreshed_repo + .write(&metadata_destination) + .await + .unwrap(); // reload repo and verify that A and B role are included let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); // verify that role A and B are included @@ -647,46 +707,52 @@ fn update_targets_flow() { let targets = vec![file1]; editor .add_target_paths(targets) + .await .unwrap() .version(targets_version) .expires(targets_expiration); // Sign A metadata - let role = editor.sign(role1_key).unwrap(); + let role = editor.sign(role1_key).await.unwrap(); let outdir = TempDir::new().unwrap(); let metadata_destination_out = outdir.as_ref().join("metadata"); let targets_destination_out = outdir.as_ref().join("targets"); // Write metadata to outdir/metata/A.json - role.write(&metadata_destination_out, false).unwrap(); + role.write(&metadata_destination_out, false).await.unwrap(); // Copy targets to outdir/targets/... role.copy_targets(targets_path(), &targets_destination_out, PathExists::Skip) + .await .unwrap(); // Add in edited A targets and update snapshot (update-repo) // load repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_out); - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); // update A metadata editor .update_delegated_targets("A", metadata_base_url_out.as_str()) + .await .unwrap() .snapshot_version(snapshot_version) .snapshot_expires(snapshot_expiration) .timestamp_version(timestamp_version) .timestamp_expires(timestamp_expiration); - let signed_repo = editor.sign(targets_key).unwrap(); + let signed_repo = editor.sign(targets_key).await.unwrap(); // write signed repo let end_repo = TempDir::new().unwrap(); @@ -694,54 +760,61 @@ fn update_targets_flow() { let metadata_destination = end_repo.as_ref().join("metadata"); let targets_destination = end_repo.as_ref().join("targets"); - signed_repo.write(&metadata_destination).unwrap(); + signed_repo.write(&metadata_destination).await.unwrap(); signed_repo .copy_targets( &targets_destination_out, &targets_destination, PathExists::Skip, ) + .await .unwrap(); //load the updated repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let file1 = TargetName::new("file1.txt").unwrap(); assert_eq!( - read_to_end(new_repo.read_target(&file1).unwrap().unwrap()), + read_to_end(new_repo.read_target(&file1).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); // Edit target "file1.txt" let mut editor = TargetsEditor::from_repo(new_repo, "A").unwrap(); File::create(targets_destination_out.join("file1.txt")) + .await .unwrap() .write_all(b"Updated file1.txt") + .await .unwrap(); let file1 = targets_destination_out.join("file1.txt"); let targets = vec![file1]; editor .add_target_paths(targets) + .await .unwrap() .version(targets_version) .expires(targets_expiration); // Sign A metadata - let role = editor.sign(role1_key).unwrap(); + let role = editor.sign(role1_key).await.unwrap(); let outdir = TempDir::new().unwrap(); let metadata_destination_output = outdir.as_ref().join("metadata"); let targets_destination_output = outdir.as_ref().join("targets"); // Write metadata to outdir/metata/A.json - role.write(&metadata_destination_output, false).unwrap(); + role.write(&metadata_destination_output, false) + .await + .unwrap(); // Copy targets to outdir/targets/... role.link_targets( @@ -749,31 +822,36 @@ fn update_targets_flow() { &targets_destination_output, PathExists::Skip, ) + .await .unwrap(); // Add in edited A targets and update snapshot (update-repo) // load repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let metadata_base_url_out = dir_url(&metadata_destination_output); let _targets_base_url_out = dir_url(&targets_destination_output); - let mut editor = RepositoryEditor::from_repo(root_path(), new_repo).unwrap(); + let mut editor = RepositoryEditor::from_repo(root_path(), new_repo) + .await + .unwrap(); // add in updated metadata editor .update_delegated_targets("A", metadata_base_url_out.as_str()) + .await .unwrap() .snapshot_version(snapshot_version) .snapshot_expires(snapshot_expiration) .timestamp_version(timestamp_version) .timestamp_expires(timestamp_expiration); - let signed_repo = editor.sign(targets_key).unwrap(); + let signed_repo = editor.sign(targets_key).await.unwrap(); // write signed repo let end_repo = TempDir::new().unwrap(); @@ -781,28 +859,30 @@ fn update_targets_flow() { let metadata_destination = end_repo.as_ref().join("metadata"); let targets_destination = end_repo.as_ref().join("targets"); - signed_repo.write(&metadata_destination).unwrap(); + signed_repo.write(&metadata_destination).await.unwrap(); signed_repo .link_targets( &targets_destination_out, &targets_destination, PathExists::Skip, ) + .await .unwrap(); //load the updated repo let root = root_path(); let new_repo = RepositoryLoader::new( - File::open(root).unwrap(), + &tokio::fs::read(root).await.unwrap(), dir_url(&metadata_destination), dir_url(&targets_destination), ) .load() + .await .unwrap(); let file1 = TargetName::new("file1.txt").unwrap(); assert_eq!( - read_to_end(new_repo.read_target(&file1).unwrap().unwrap()), + read_to_end(new_repo.read_target(&file1).await.unwrap().unwrap()).await, &b"Updated file1.txt"[..] ); } diff --git a/tough/tests/rotated_root.rs b/tough/tests/rotated_root.rs index 58625db7..67733193 100644 --- a/tough/tests/rotated_root.rs +++ b/tough/tests/rotated_root.rs @@ -3,20 +3,20 @@ mod test_utils; -use std::fs::File; use test_utils::{dir_url, test_data}; use tough::RepositoryLoader; -#[test] -fn rotated_root() { +#[tokio::test] +async fn rotated_root() { let base = test_data().join("rotated-root"); let repo = RepositoryLoader::new( - File::open(base.join("1.root.json")).unwrap(), + &tokio::fs::read(base.join("1.root.json")).await.unwrap(), dir_url(&base), dir_url(base.join("targets")), ) .load() + .await .unwrap(); assert_eq!(u64::from(repo.root().signed.version), 2); diff --git a/tough/tests/target_path_safety.rs b/tough/tests/target_path_safety.rs index 61af5874..a7dc3228 100644 --- a/tough/tests/target_path_safety.rs +++ b/tough/tests/target_path_safety.rs @@ -4,11 +4,11 @@ use chrono::{DateTime, TimeZone, Utc}; use maplit::hashmap; use ring::rand::SystemRandom; use std::collections::HashMap; -use std::fs::{self, create_dir_all, File}; use std::num::NonZeroU64; use std::path::Path; use tempfile::TempDir; use test_utils::{dir_url, test_data, DATA_1, DATA_2, DATA_3}; +use tokio::fs; use tough::editor::signed::SignedRole; use tough::editor::RepositoryEditor; use tough::key_source::{KeySource, LocalKeySource}; @@ -23,12 +23,12 @@ fn later() -> DateTime { } /// This test ensures that we can safely handle path-like target names with ../'s in them. -fn create_root(root_path: &Path, consistent_snapshot: bool) -> Vec> { +async fn create_root(root_path: &Path, consistent_snapshot: bool) -> Vec> { let keys: Vec> = vec![Box::new(LocalKeySource { path: test_data().join("snakeoil.pem"), })]; - let key_pair = keys.get(0).unwrap().as_sign().unwrap().tuf_key(); + let key_pair = keys.get(0).unwrap().as_sign().await.unwrap().tuf_key(); let key_id = key_pair.key_id().unwrap(); let empty_keys = RoleKeys { @@ -64,21 +64,24 @@ fn create_root(root_path: &Path, consistent_snapshot: bool) -> Vec>(path: P) -> Url { } /// Gets the goods from a read and makes a Vec -pub fn read_to_end(mut reader: R) -> Vec { - let mut v = Vec::new(); - reader.read_to_end(&mut v).unwrap(); - v +pub async fn read_to_end(mut stream: S) -> Vec +where + E: std::fmt::Debug, + S: IntoVec, +{ + stream.into_vec().await.unwrap() } diff --git a/tough/tests/transport.rs b/tough/tests/transport.rs index fd7d419d..a5f7d70d 100644 --- a/tough/tests/transport.rs +++ b/tough/tests/transport.rs @@ -1,7 +1,7 @@ -use std::fs; use std::str::FromStr; use tempfile::TempDir; use test_utils::read_to_end; +use tokio::fs; use tough::{DefaultTransport, Transport, TransportErrorKind}; use url::Url; @@ -10,11 +10,11 @@ mod test_utils; /// If the `http` feature is not enabled, we should get an error message indicating that the feature /// is not enabled. #[cfg(not(feature = "http"))] -#[test] -fn default_transport_error_no_http() { +#[tokio::test] +async fn default_transport_error_no_http() { let transport = DefaultTransport::new(); let url = Url::from_str("http://example.com").unwrap(); - let error = transport.fetch(url).err().unwrap(); + let error = transport.fetch(url).await.err().unwrap(); match error.kind() { TransportErrorKind::UnsupportedUrlScheme => { let message = format!("{}", error); @@ -24,26 +24,26 @@ fn default_transport_error_no_http() { } } -#[test] -fn default_transport_error_ftp() { +#[tokio::test] +async fn default_transport_error_ftp() { let transport = DefaultTransport::new(); let url = Url::from_str("ftp://example.com").unwrap(); - let error = transport.fetch(url.clone()).err().unwrap(); + let error = transport.fetch(url.clone()).await.err().unwrap(); match error.kind() { TransportErrorKind::UnsupportedUrlScheme => assert_eq!(error.url(), url.as_str()), _ => panic!("incorrect error kind, expected UnsupportedUrlScheme"), } } -#[test] -fn default_transport_file() { +#[tokio::test] +async fn default_transport_file() { let dir = TempDir::new().unwrap(); let filepath = dir.path().join("file.txt"); - fs::write(&filepath, "123123987").unwrap(); + fs::write(&filepath, "123123987").await.unwrap(); let transport = DefaultTransport::new(); let url = Url::from_file_path(filepath).unwrap(); - let read = transport.fetch(url).unwrap(); - let temp_vec = read_to_end(read); + let read = transport.fetch(url).await.unwrap(); + let temp_vec = read_to_end(read).await; let contents = String::from_utf8_lossy(&temp_vec); assert_eq!(contents, "123123987"); } diff --git a/tuftool/Cargo.toml b/tuftool/Cargo.toml index 0e2fb38e..f9ff405e 100644 --- a/tuftool/Cargo.toml +++ b/tuftool/Cargo.toml @@ -21,19 +21,21 @@ aws-sdk-kms = "0.28" aws-sdk-ssm = "0.28" chrono = { version = "0.4", default-features = false, features = ["alloc", "std", "clock"] } clap = { version = "4", features = ["derive"] } +futures = "0.3.28" hex = "0.4" log = "0.4" maplit = "1" olpc-cjson = { version = "0.1", path = "../olpc-cjson" } pem = "3" rayon = "1" -reqwest = { version = "0.11", default-features = false, features = ["blocking", "rustls-tls"] } +reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } ring = { version = "0.16", features = ["std"] } serde = "1" serde_json = "1" simplelog = "0.12" snafu = { version = "0.7", features = ["backtraces-impl-backtrace-crate"] } tempfile = "3" +tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } tough = { version = "0.14", path = "../tough", features = ["http"] } tough-kms = { version = "0.6", path = "../tough-kms" } tough-ssm = { version = "0.9", path = "../tough-ssm" } @@ -42,4 +44,6 @@ walkdir = "2" [dev-dependencies] assert_cmd = "2" +futures = "0.3.28" +futures-core = "0.3.28" httptest = "0.15" diff --git a/tuftool/src/add_key_role.rs b/tuftool/src/add_key_role.rs index e9726999..35d6e295 100644 --- a/tuftool/src/add_key_role.rs +++ b/tuftool/src/add_key_role.rs @@ -51,24 +51,26 @@ pub(crate) struct AddKeyArgs { } impl AddKeyArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { + pub(crate) async fn run(&self, role: &str) -> Result<()> { // load the repo - let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone())?; + let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone()).await?; self.add_key( role, TargetsEditor::from_repo(repository, role) .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } /// Adds keys to a role using targets Editor - fn add_key(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { + async fn add_key(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { // create the keypairs to add let mut key_pairs = HashMap::new(); for source in &self.new_keys { let key_source = parse_key_source(source)?; let key_pair = key_source .as_sign() + .await .context(error::KeyPairFromKeySourceSnafu)? .tuf_key(); key_pairs.insert( @@ -92,10 +94,12 @@ impl AddKeyArgs { .version(self.version) .expires(self.expires) .sign(&keys) + .await .context(error::SignRepoSnafu)?; let metadata_destination_out = &self.outdir.join("metadata"); updated_role .write(metadata_destination_out, false) + .await .context(error::WriteRolesSnafu { roles: [role.to_string()].to_vec(), })?; diff --git a/tuftool/src/add_role.rs b/tuftool/src/add_role.rs index 9da9df97..a097fe73 100644 --- a/tuftool/src/add_role.rs +++ b/tuftool/src/add_role.rs @@ -83,17 +83,19 @@ pub(crate) struct AddRoleArgs { } impl AddRoleArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { + pub(crate) async fn run(&self, role: &str) -> Result<()> { // load the repo - let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone())?; + let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone()).await?; // if sign_all use Repository Editor to sign the entire repo if not use targets editor if self.sign_all { // Add a role using a `RepositoryEditor` self.with_repo_editor( role, RepositoryEditor::from_repo(&self.root, repository) + .await .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } else { // Add a role using a `TargetsEditor` self.add_role( @@ -101,12 +103,13 @@ impl AddRoleArgs { TargetsEditor::from_repo(repository, role) .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } } #[allow(clippy::option_if_let_else)] /// Adds a role to metadata using targets Editor - fn add_role(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { + async fn add_role(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { let paths = if let Some(paths) = &self.paths { PathSet::Paths(paths.clone()) } else if let Some(path_hash_prefixes) = &self.path_hash_prefixes { @@ -130,14 +133,17 @@ impl AddRoleArgs { self.threshold, None, ) + .await .context(error::LoadMetadataSnafu)? .version(self.version) .expires(self.expires) .sign(&keys) + .await .context(error::SignRepoSnafu)?; let metadata_destination_out = &self.outdir.join("metadata"); updated_role .write(metadata_destination_out, false) + .await .context(error::WriteRolesSnafu { roles: [self.delegatee.clone(), role.to_string()].to_vec(), })?; @@ -147,7 +153,7 @@ impl AddRoleArgs { #[allow(clippy::option_if_let_else)] /// Adds a role to metadata using repo Editor - fn with_repo_editor(&self, role: &str, mut editor: RepositoryEditor) -> Result<()> { + async fn with_repo_editor(&self, role: &str, mut editor: RepositoryEditor) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -183,6 +189,7 @@ impl AddRoleArgs { .targets_expires(self.expires) .context(error::DelegationStructureSnafu)? .sign_targets_editor(&keys) + .await .context(error::DelegateeNotFoundSnafu { role: role.to_string(), })?; @@ -201,6 +208,7 @@ impl AddRoleArgs { self.threshold, None, ) + .await .context(error::LoadMetadataSnafu)? .targets_version(self.version) .context(error::DelegationStructureSnafu)? @@ -211,10 +219,11 @@ impl AddRoleArgs { .timestamp_version(timestamp_version) .timestamp_expires(timestamp_expires); - let signed_repo = editor.sign(&keys).context(error::SignRepoSnafu)?; + let signed_repo = editor.sign(&keys).await.context(error::SignRepoSnafu)?; let metadata_destination_out = &self.outdir.join("metadata"); signed_repo .write(metadata_destination_out) + .await .context(error::WriteRolesSnafu { roles: [self.delegatee.clone(), role.to_string()].to_vec(), })?; diff --git a/tuftool/src/clone.rs b/tuftool/src/clone.rs index 10f4846e..8cf1479f 100644 --- a/tuftool/src/clone.rs +++ b/tuftool/src/clone.rs @@ -6,7 +6,6 @@ use crate::download_root::download_root; use crate::error::{self, Result}; use clap::Parser; use snafu::ResultExt; -use std::fs::File; use std::num::NonZeroU64; use std::path::PathBuf; use tough::{ExpirationEnforcement, RepositoryLoader}; @@ -64,13 +63,13 @@ WARNING: repo metadata is expired, meaning the owner hasn't verified its content } impl CloneArgs { - pub(crate) fn run(&self) -> Result<()> { + pub(crate) async fn run(&self) -> Result<()> { // Use local root.json or download from repository let root_path = if let Some(path) = &self.root { PathBuf::from(path) } else if self.allow_root_download { let outdir = std::env::current_dir().context(error::CurrentDirSnafu)?; - download_root(&self.metadata_base_url, self.root_version, outdir)? + download_root(&self.metadata_base_url, self.root_version, outdir).await? } else { eprintln!("No root.json available"); std::process::exit(1); @@ -96,12 +95,15 @@ impl CloneArgs { ExpirationEnforcement::Safe }; let repository = RepositoryLoader::new( - File::open(&root_path).context(error::OpenRootSnafu { path: &root_path })?, + &tokio::fs::read(&root_path) + .await + .context(error::OpenRootSnafu { path: &root_path })?, self.metadata_base_url.clone(), targets_base_url, ) .expiration_enforcement(expiration_enforcement) .load() + .await .context(error::RepoLoadSnafu)?; // Clone the repository, downloading none, all, or a subset of targets @@ -109,6 +111,7 @@ impl CloneArgs { println!("Cloning repository metadata to {:?}", self.metadata_dir); repository .cache_metadata(&self.metadata_dir, true) + .await .context(error::CloneRepositorySnafu)?; } else { // Similar to `targets_base_url, structopt's guard rails won't let us have a @@ -125,6 +128,7 @@ impl CloneArgs { if self.target_names.is_empty() { repository .cache(&self.metadata_dir, targets_dir, None::<&[&str]>, true) + .await .context(error::CloneRepositorySnafu)?; } else { repository @@ -134,6 +138,7 @@ impl CloneArgs { Some(self.target_names.as_slice()), true, ) + .await .context(error::CloneRepositorySnafu)?; } }; diff --git a/tuftool/src/common.rs b/tuftool/src/common.rs index b9a182db..5a7dcd3f 100644 --- a/tuftool/src/common.rs +++ b/tuftool/src/common.rs @@ -1,7 +1,6 @@ /// This module is for code that is re-used by different `tuftool` subcommands. use crate::error::{self, Result}; use snafu::ResultExt; -use std::fs::File; use std::path::Path; use tough::{Repository, RepositoryLoader}; use url::Url; @@ -17,13 +16,15 @@ pub(crate) const UNUSED_URL: &str = "file:///unused/url"; /// - `root` must be a path to a file that can be opened with `File::open`. /// - `metadata_url` can be local or remote. /// -pub(crate) fn load_metadata_repo

(root: P, metadata_url: Url) -> Result +pub(crate) async fn load_metadata_repo

(root: P, metadata_url: Url) -> Result where P: AsRef, { let root = root.as_ref(); RepositoryLoader::new( - File::open(root).context(error::OpenRootSnafu { path: root })?, + &tokio::fs::read(root) + .await + .context(error::OpenRootSnafu { path: root })?, metadata_url, // we don't do anything with the targets url for metadata operations Url::parse(UNUSED_URL).with_context(|_| error::UrlParseSnafu { @@ -31,5 +32,6 @@ where })?, ) .load() + .await .context(error::RepoLoadSnafu) } diff --git a/tuftool/src/create.rs b/tuftool/src/create.rs index 228725b1..586f29e5 100644 --- a/tuftool/src/create.rs +++ b/tuftool/src/create.rs @@ -78,7 +78,7 @@ pub(crate) struct CreateArgs { } impl CreateArgs { - pub(crate) fn run(&self) -> Result<()> { + pub(crate) async fn run(&self) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -94,8 +94,9 @@ impl CreateArgs { .context(error::InitializeThreadPoolSnafu)?; } - let targets = build_targets(&self.targets_indir, self.follow)?; + let targets = build_targets(&self.targets_indir, self.follow).await?; let mut editor = RepositoryEditor::new(&self.root) + .await .context(error::EditorCreateSnafu { path: &self.root })?; editor @@ -114,18 +115,20 @@ impl CreateArgs { .context(error::DelegationStructureSnafu)?; } - let signed_repo = editor.sign(&keys).context(error::SignRepoSnafu)?; + let signed_repo = editor.sign(&keys).await.context(error::SignRepoSnafu)?; let metadata_dir = &self.outdir.join("metadata"); let targets_outdir = &self.outdir.join("targets"); signed_repo .link_targets(&self.targets_indir, targets_outdir, self.target_path_exists) + .await .context(error::LinkTargetsSnafu { indir: &self.targets_indir, outdir: targets_outdir, })?; signed_repo .write(metadata_dir) + .await .context(error::WriteRepoSnafu { directory: metadata_dir, })?; diff --git a/tuftool/src/create_role.rs b/tuftool/src/create_role.rs index e01879a1..9b7f32ca 100644 --- a/tuftool/src/create_role.rs +++ b/tuftool/src/create_role.rs @@ -37,7 +37,7 @@ pub(crate) struct CreateRoleArgs { } impl CreateRoleArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { + pub(crate) async fn run(&self, role: &str) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -48,14 +48,16 @@ impl CreateRoleArgs { let new_role = TargetsEditor::new(role) .version(self.version) .expires(self.expires) - .add_key(key_hash_map(&keys), None) + .add_key(key_hash_map(&keys).await, None) .context(error::DelegationStructureSnafu)? .sign(&keys) + .await .context(error::SignRepoSnafu)?; // write the new role let metadata_destination_out = &self.outdir.join("metadata"); new_role .write(metadata_destination_out, false) + .await .context(error::WriteRolesSnafu { roles: [role.to_string()].to_vec(), })?; @@ -63,10 +65,10 @@ impl CreateRoleArgs { } } -fn key_hash_map(keys: &[Box]) -> HashMap, Key> { +async fn key_hash_map(keys: &[Box]) -> HashMap, Key> { let mut key_pairs = HashMap::new(); for source in keys { - let key_pair = source.as_sign().unwrap().tuf_key(); + let key_pair = source.as_sign().await.unwrap().tuf_key(); key_pairs.insert(key_pair.key_id().unwrap().clone(), key_pair.clone()); } key_pairs diff --git a/tuftool/src/download.rs b/tuftool/src/download.rs index 6492e23c..e660b19a 100644 --- a/tuftool/src/download.rs +++ b/tuftool/src/download.rs @@ -5,7 +5,6 @@ use crate::download_root::download_root; use crate::error::{self, Result}; use clap::Parser; use snafu::{ensure, ResultExt}; -use std::fs::File; use std::num::NonZeroU64; use std::path::{Path, PathBuf}; use tough::{ExpirationEnforcement, Prefix, Repository, RepositoryLoader, TargetName}; @@ -56,7 +55,7 @@ WARNING: `--allow-expired-repo` was passed; this is unsafe and will not establis } impl DownloadArgs { - pub(crate) fn run(&self) -> Result<()> { + pub(crate) async fn run(&self) -> Result<()> { // To help ensure that downloads are safe, we require that the outdir does not exist. ensure!( !self.outdir.exists(), @@ -68,7 +67,7 @@ impl DownloadArgs { PathBuf::from(path) } else if self.allow_root_download { let outdir = std::env::current_dir().context(error::CurrentDirSnafu)?; - download_root(&self.metadata_base_url, self.root_version, outdir)? + download_root(&self.metadata_base_url, self.root_version, outdir).await? } else { eprintln!("No root.json available"); std::process::exit(1); @@ -82,29 +81,37 @@ impl DownloadArgs { ExpirationEnforcement::Safe }; let repository = RepositoryLoader::new( - File::open(&root_path).context(error::OpenRootSnafu { path: &root_path })?, + &tokio::fs::read(&root_path) + .await + .context(error::OpenRootSnafu { path: &root_path })?, self.metadata_base_url.clone(), self.targets_base_url.clone(), ) .expiration_enforcement(expiration_enforcement) .load() + .await .context(error::RepoLoadSnafu)?; // download targets - handle_download(&repository, &self.outdir, &self.target_names) + handle_download(&repository, &self.outdir, &self.target_names).await } } -fn handle_download(repository: &Repository, outdir: &Path, raw_names: &[String]) -> Result<()> { +async fn handle_download( + repository: &Repository, + outdir: &Path, + raw_names: &[String], +) -> Result<()> { let target_names: Result> = raw_names .iter() .map(|s| TargetName::new(s).context(error::InvalidTargetNameSnafu)) .collect(); let target_names = target_names?; - let download_target = |name: &TargetName| -> Result<()> { + let download_target = |name: TargetName| async move { println!("\t-> {}", name.raw()); repository - .save_target(name, outdir, Prefix::None) + .save_target(&name, outdir, Prefix::None) + .await .context(error::MetadataSnafu)?; Ok(()) }; @@ -123,9 +130,11 @@ fn handle_download(repository: &Repository, outdir: &Path, raw_names: &[String]) }; println!("Downloading targets to {outdir:?}"); - std::fs::create_dir_all(outdir).context(error::DirCreateSnafu { path: outdir })?; + tokio::fs::create_dir_all(outdir) + .await + .context(error::DirCreateSnafu { path: outdir })?; for target in targets { - download_target(&target)?; + download_target(target).await?; } Ok(()) } diff --git a/tuftool/src/download_root.rs b/tuftool/src/download_root.rs index aa04a73c..5d59c9ba 100644 --- a/tuftool/src/download_root.rs +++ b/tuftool/src/download_root.rs @@ -3,15 +3,17 @@ //! The `download_root` module owns the logic for downloading a given version of `root.json`. use crate::error::{self, Result}; +use futures::StreamExt; use snafu::ResultExt; -use std::fs::File; use std::num::NonZeroU64; use std::path::{Path, PathBuf}; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; use url::Url; /// Download the given version of `root.json` /// This is an unsafe operation, and doesn't establish trust. It should only be used for testing! -pub(crate) fn download_root

( +pub(crate) async fn download_root

( metadata_base_url: &Url, version: NonZeroU64, outdir: P, @@ -29,15 +31,23 @@ where })?; root_warning(&path); - let mut root_request = reqwest::blocking::get(url.as_str()) + let root_request = reqwest::get(url.as_str()) + .await .context(error::ReqwestGetSnafu)? .error_for_status() .context(error::BadResponseSnafu { url })?; - let mut f = File::create(&path).context(error::OpenFileSnafu { path: &path })?; - root_request - .copy_to(&mut f) - .context(error::ReqwestCopySnafu)?; + let mut f = File::create(&path) + .await + .context(error::OpenFileSnafu { path: &path })?; + + let bytes_stream = &mut root_request.bytes_stream(); + while let Some(bytes) = bytes_stream.next().await { + let bytes = bytes.context(error::ReqwestCopySnafu)?; + f.write_all(&bytes) + .await + .with_context(|_| error::FileWriteSnafu { path: path.clone() })?; + } Ok(path) } diff --git a/tuftool/src/error.rs b/tuftool/src/error.rs index f3571065..6baaed06 100644 --- a/tuftool/src/error.rs +++ b/tuftool/src/error.rs @@ -378,6 +378,12 @@ pub(crate) enum Error { source: std::io::Error, backtrace: Backtrace, }, + + #[snafu(display("Failed to join a task: {}", source))] + JoinTask { + source: tokio::task::JoinError, + backtrace: Backtrace, + }, } // Extracts the status code from a reqwest::Error and converts it to a string to be displayed diff --git a/tuftool/src/main.rs b/tuftool/src/main.rs index 345457e6..55434698 100644 --- a/tuftool/src/main.rs +++ b/tuftool/src/main.rs @@ -32,14 +32,13 @@ mod update_targets; use crate::error::Result; use clap::Parser; -use rayon::prelude::*; +use futures::{StreamExt, TryStreamExt}; use simplelog::{ColorChoice, ConfigBuilder, LevelFilter, TermLogger, TerminalMode}; use snafu::{ErrorCompat, OptionExt, ResultExt}; use std::collections::HashMap; -use std::fs::File; -use std::io::Write; use std::path::Path; use tempfile::NamedTempFile; +use tokio::io::AsyncWriteExt; use tough::schema::Target; use tough::TargetName; use walkdir::WalkDir; @@ -57,7 +56,7 @@ struct Program { } impl Program { - fn run(self) -> Result<()> { + async fn run(self) -> Result<()> { TermLogger::init( self.log_level, ConfigBuilder::new() @@ -68,7 +67,7 @@ impl Program { ColorChoice::Auto, ) .context(error::LoggerSnafu)?; - self.cmd.run() + self.cmd.run().await } } @@ -92,70 +91,112 @@ enum Command { } impl Command { - fn run(self) -> Result<()> { + async fn run(self) -> Result<()> { match self { - Command::Create(args) => args.run(), - Command::Root(root_subcommand) => root_subcommand.run(), - Command::Download(args) => args.run(), - Command::Update(args) => args.run(), - Command::Delegation(cmd) => cmd.run(), - Command::Clone(cmd) => cmd.run(), - Command::TransferMetadata(cmd) => cmd.run(), + Command::Create(args) => args.run().await, + Command::Root(root_subcommand) => root_subcommand.run().await, + Command::Download(args) => args.run().await, + Command::Update(args) => args.run().await, + Command::Delegation(cmd) => cmd.run().await, + Command::Clone(cmd) => cmd.run().await, + Command::TransferMetadata(cmd) => cmd.run().await, } } } -fn load_file(path: &Path) -> Result +async fn load_file(path: &Path) -> Result where for<'de> T: serde::Deserialize<'de>, { - serde_json::from_reader(File::open(path).context(error::FileOpenSnafu { path })?) - .context(error::FileParseJsonSnafu { path }) + serde_json::from_slice( + &tokio::fs::read(path) + .await + .context(error::FileOpenSnafu { path })?, + ) + .context(error::FileParseJsonSnafu { path }) } -fn write_file(path: &Path, json: &T) -> Result<()> +async fn write_file(path: &Path, json: &T) -> Result<()> where T: serde::Serialize, { // Use `tempfile::NamedTempFile::persist` to perform an atomic file write. let parent = path.parent().context(error::PathParentSnafu { path })?; - let mut writer = + let file = NamedTempFile::new_in(parent).context(error::FileTempCreateSnafu { path: parent })?; - serde_json::to_writer_pretty(&mut writer, json).context(error::FileWriteJsonSnafu { path })?; - writer - .write_all(b"\n") + + let (file, tmp_path) = file.into_parts(); + let mut file = tokio::fs::File::from_std(file); + + let buf = serde_json::to_vec_pretty(json).context(error::FileWriteJsonSnafu { path })?; + file.write_all(&buf) + .await .context(error::FileWriteSnafu { path })?; - writer + + let file = file.into_std().await; + NamedTempFile::from_parts(file, tmp_path) .persist(path) .context(error::FilePersistSnafu { path })?; + Ok(()) } // Walk the directory specified, building a map of filename to Target structs. // Hashing of the targets is done in parallel -fn build_targets

(indir: P, follow_links: bool) -> Result> +async fn build_targets

(indir: P, follow_links: bool) -> Result> where P: AsRef, { - let indir = indir.as_ref(); - WalkDir::new(indir) - .follow_links(follow_links) - .into_iter() - .par_bridge() - .filter_map(|entry| match entry { - Ok(entry) => { - if entry.file_type().is_file() { - Some(process_target(entry.path())) - } else { - None + let indir = indir.as_ref().to_owned(); + + let (tx, rx) = tokio::sync::mpsc::channel(10); + let indir_clone = indir.clone(); + tokio::task::spawn_blocking(move || -> Result<()> { + let walker = WalkDir::new(indir_clone.clone()).follow_links(follow_links); + + for entry in walker { + if tx.blocking_send(entry).is_err() { + // Receiver error'ed out + break; + }; + } + Ok(()) + }); + + // Spawn tasks to process targets concurrently. + let join_handles = + futures::stream::unfold( + rx, + move |mut rx| async move { Some((rx.recv().await?, rx)) }, + ) + .filter_map(|entry| { + let indir = indir.clone(); + async move { + match entry { + Ok(entry) => { + if entry.file_type().is_file() { + let future = async move { process_target(entry.path()).await }; + Some(Ok(tokio::task::spawn(future))) + } else { + None + } + } + Err(err) => Some(Err(err).context(error::WalkDirSnafu { directory: indir })), } } - Err(err) => Some(Err(err).context(error::WalkDirSnafu { directory: indir })), }) + .try_collect::>() + .await?; + + // Await all tasks. + futures::future::try_join_all(join_handles) + .await + .context(error::JoinTaskSnafu {})? + .into_iter() .collect() } -fn process_target(path: &Path) -> Result<(TargetName, Target)> { +async fn process_target(path: &Path) -> Result<(TargetName, Target)> { // Get the file name as a TargetName let target_name = TargetName::new( path.file_name() @@ -166,13 +207,16 @@ fn process_target(path: &Path) -> Result<(TargetName, Target)> { .context(error::InvalidTargetNameSnafu)?; // Build a Target from the path given. If it is not a file, this will fail - let target = Target::from_path(path).context(error::TargetFromPathSnafu { path })?; + let target = Target::from_path(path) + .await + .context(error::TargetFromPathSnafu { path })?; Ok((target_name, target)) } -fn main() -> ! { - std::process::exit(match Program::parse().run() { +#[tokio::main] +async fn main() -> ! { + std::process::exit(match Program::parse().run().await { Ok(()) => 0, Err(err) => { eprintln!("{err}"); @@ -199,8 +243,8 @@ struct Delegation { } impl Delegation { - fn run(self) -> Result<()> { - self.cmd.run(&self.role) + async fn run(self) -> Result<()> { + self.cmd.run(&self.role).await } } @@ -221,14 +265,14 @@ enum DelegationCommand { } impl DelegationCommand { - fn run(self, role: &str) -> Result<()> { + async fn run(self, role: &str) -> Result<()> { match self { - DelegationCommand::CreateRole(args) => args.run(role), - DelegationCommand::AddRole(args) => args.run(role), - DelegationCommand::UpdateDelegatedTargets(args) => args.run(role), - DelegationCommand::AddKey(args) => args.run(role), - DelegationCommand::RemoveKey(args) => args.run(role), - DelegationCommand::Remove(args) => args.run(role), + DelegationCommand::CreateRole(args) => args.run(role).await, + DelegationCommand::AddRole(args) => args.run(role).await, + DelegationCommand::UpdateDelegatedTargets(args) => args.run(role).await, + DelegationCommand::AddKey(args) => args.run(role).await, + DelegationCommand::RemoveKey(args) => args.run(role).await, + DelegationCommand::Remove(args) => args.run(role).await, } } } diff --git a/tuftool/src/remove_key_role.rs b/tuftool/src/remove_key_role.rs index 404d8665..abd3fb6d 100644 --- a/tuftool/src/remove_key_role.rs +++ b/tuftool/src/remove_key_role.rs @@ -51,17 +51,18 @@ pub(crate) struct RemoveKeyArgs { } impl RemoveKeyArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { - let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone())?; + pub(crate) async fn run(&self, role: &str) -> Result<()> { + let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone()).await?; self.remove_key( role, TargetsEditor::from_repo(repository, role) .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } /// Removes keys from a delegated role using targets Editor - fn remove_key(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { + async fn remove_key(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -74,10 +75,12 @@ impl RemoveKeyArgs { .version(self.version) .expires(self.expires) .sign(&keys) + .await .context(error::SignRepoSnafu)?; let metadata_destination_out = &self.outdir.join("metadata"); updated_role .write(metadata_destination_out, false) + .await .context(error::WriteRolesSnafu { roles: [role.to_string()].to_vec(), })?; diff --git a/tuftool/src/remove_role.rs b/tuftool/src/remove_role.rs index b3835e22..1856a0fe 100644 --- a/tuftool/src/remove_role.rs +++ b/tuftool/src/remove_role.rs @@ -50,17 +50,18 @@ pub(crate) struct RemoveRoleArgs { } impl RemoveRoleArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { - let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone())?; + pub(crate) async fn run(&self, role: &str) -> Result<()> { + let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone()).await?; self.remove_delegated_role( role, TargetsEditor::from_repo(repository, role) .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } /// Removes a delegated role from a `Targets` role using `TargetsEditor` - fn remove_delegated_role(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { + async fn remove_delegated_role(&self, role: &str, mut editor: TargetsEditor) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -73,10 +74,12 @@ impl RemoveRoleArgs { .version(self.version) .expires(self.expires) .sign(&keys) + .await .context(error::SignRepoSnafu)?; let metadata_destination_out = &self.outdir.join("metadata"); updated_role .write(metadata_destination_out, false) + .await .context(error::WriteRolesSnafu { roles: [role.to_string()].to_vec(), })?; diff --git a/tuftool/src/root.rs b/tuftool/src/root.rs index 36a1c71a..c9239209 100644 --- a/tuftool/src/root.rs +++ b/tuftool/src/root.rs @@ -132,30 +132,32 @@ macro_rules! role_keys { } impl Command { - pub(crate) fn run(self) -> Result<()> { + pub(crate) async fn run(self) -> Result<()> { match self { - Command::Init { path, version } => Command::init(&path, version), - Command::BumpVersion { path } => Command::bump_version(&path), - Command::Expire { path, time } => Command::expire(&path, &time), + Command::Init { path, version } => Command::init(&path, version).await, + Command::BumpVersion { path } => Command::bump_version(&path).await, + Command::Expire { path, time } => Command::expire(&path, &time).await, Command::SetThreshold { path, role, threshold, - } => Command::set_threshold(&path, role, threshold), - Command::SetVersion { path, version } => Command::set_version(&path, version), + } => Command::set_threshold(&path, role, threshold).await, + Command::SetVersion { path, version } => Command::set_version(&path, version).await, Command::AddKey { path, roles, key_source, - } => Command::add_key(&path, &roles, &key_source), - Command::RemoveKey { path, key_id, role } => Command::remove_key(&path, &key_id, role), + } => Command::add_key(&path, &roles, &key_source).await, + Command::RemoveKey { path, key_id, role } => { + Command::remove_key(&path, &key_id, role).await + } Command::GenRsaKey { path, roles, key_source, bits, exponent, - } => Command::gen_rsa_key(&path, &roles, &key_source, bits, exponent), + } => Command::gen_rsa_key(&path, &roles, &key_source, bits, exponent).await, Command::Sign { path, key_sources, @@ -167,12 +169,12 @@ impl Command { let key_source = parse_key_source(source)?; keys.push(key_source); } - Command::sign(&path, &keys, cross_sign, ignore_threshold) + Command::sign(&path, &keys, cross_sign, ignore_threshold).await } } } - fn init(path: &Path, version: Option) -> Result<()> { + async fn init(path: &Path, version: Option) -> Result<()> { let init_version = version.unwrap_or(1); write_file( path, @@ -194,10 +196,11 @@ impl Command { signatures: Vec::new(), }, ) + .await } - fn bump_version(path: &Path) -> Result<()> { - let mut root: Signed = load_file(path)?; + async fn bump_version(path: &Path) -> Result<()> { + let mut root: Signed = load_file(path).await?; root.signed.version = NonZeroU64::new( root.signed .version @@ -207,58 +210,59 @@ impl Command { ) .context(error::VersionZeroSnafu)?; clear_sigs(&mut root); - write_file(path, &root) + write_file(path, &root).await } - fn expire(path: &Path, time: &DateTime) -> Result<()> { - let mut root: Signed = load_file(path)?; + async fn expire(path: &Path, time: &DateTime) -> Result<()> { + let mut root: Signed = load_file(path).await?; root.signed.expires = round_time(*time); clear_sigs(&mut root); - write_file(path, &root) + write_file(path, &root).await } - fn set_threshold(path: &Path, role: RoleType, threshold: NonZeroU64) -> Result<()> { - let mut root: Signed = load_file(path)?; + async fn set_threshold(path: &Path, role: RoleType, threshold: NonZeroU64) -> Result<()> { + let mut root: Signed = load_file(path).await?; root.signed .roles .entry(role) .and_modify(|rk| rk.threshold = threshold) .or_insert_with(|| role_keys!(threshold)); clear_sigs(&mut root); - write_file(path, &root) + write_file(path, &root).await } - fn set_version(path: &Path, version: NonZeroU64) -> Result<()> { - let mut root: Signed = load_file(path)?; + async fn set_version(path: &Path, version: NonZeroU64) -> Result<()> { + let mut root: Signed = load_file(path).await?; root.signed.version = version; clear_sigs(&mut root); - write_file(path, &root) + write_file(path, &root).await } #[allow(clippy::borrowed_box)] - fn add_key(path: &Path, roles: &[RoleType], key_source: &Vec) -> Result<()> { + async fn add_key(path: &Path, roles: &[RoleType], key_source: &Vec) -> Result<()> { let mut keys = Vec::new(); for source in key_source { let key_source = parse_key_source(source)?; keys.push(key_source); } - let mut root: Signed = load_file(path)?; + let mut root: Signed = load_file(path).await?; clear_sigs(&mut root); for ks in keys { let key_pair = ks .as_sign() + .await .context(error::KeyPairFromKeySourceSnafu)? .tuf_key(); let key_id = hex::encode(add_key(&mut root.signed, roles, key_pair)?); println!("Added key: {key_id}"); } - write_file(path, &root) + write_file(path, &root).await } - fn remove_key(path: &Path, key_id: &Decoded, role: Option) -> Result<()> { - let mut root: Signed = load_file(path)?; + async fn remove_key(path: &Path, key_id: &Decoded, role: Option) -> Result<()> { + let mut root: Signed = load_file(path).await?; if let Some(role) = role { if let Some(role_keys) = root.signed.roles.get_mut(&role) { role_keys @@ -278,18 +282,18 @@ impl Command { root.signed.keys.remove(key_id); } clear_sigs(&mut root); - write_file(path, &root) + write_file(path, &root).await } #[allow(clippy::borrowed_box)] - fn gen_rsa_key( + async fn gen_rsa_key( path: &Path, roles: &[RoleType], key_source: &str, bits: u16, exponent: u32, ) -> Result<()> { - let mut root: Signed = load_file(path)?; + let mut root: Signed = load_file(path).await?; // ring doesn't support RSA key generation yet // https://github.com/briansmith/ring/issues/219 @@ -317,23 +321,24 @@ impl Command { let key_id = hex::encode(add_key(&mut root.signed, roles, key_pair.tuf_key())?); let key = parse_key_source(key_source)?; key.write(&stdout, &key_id) + .await .context(error::WriteKeySourceSnafu)?; clear_sigs(&mut root); println!("{key_id}"); - write_file(path, &root) + write_file(path, &root).await } - fn sign( + async fn sign( path: &Path, key_source: &[Box], cross_sign: Option, ignore_threshold: bool, ) -> Result<()> { - let root: Signed = load_file(path)?; + let root: Signed = load_file(path).await?; // get the root based on cross-sign let loaded_root = match cross_sign { None => root.clone(), - Some(cross_sign_root) => load_file(&cross_sign_root)?, + Some(cross_sign_root) => load_file(&cross_sign_root).await?, }; // sign the root let mut signed_root = SignedRole::new( @@ -342,6 +347,7 @@ impl Command { key_source, &SystemRandom::new(), ) + .await .context(error::SignRootSnafu { path })?; // append the existing signatures if present diff --git a/tuftool/src/transfer_metadata.rs b/tuftool/src/transfer_metadata.rs index f492553b..3590c812 100644 --- a/tuftool/src/transfer_metadata.rs +++ b/tuftool/src/transfer_metadata.rs @@ -7,7 +7,6 @@ use crate::source::parse_key_source; use chrono::{DateTime, Utc}; use clap::Parser; use snafu::ResultExt; -use std::fs::File; use std::num::NonZeroU64; use std::path::{Path, PathBuf}; use tough::editor::RepositoryEditor; @@ -81,7 +80,7 @@ WARNING: `--allow-expired-repo` was passed; this is unsafe and will not establis } impl TransferMetadataArgs { - pub(crate) fn run(&self) -> Result<()> { + pub(crate) async fn run(&self) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -99,17 +98,21 @@ impl TransferMetadataArgs { ExpirationEnforcement::Safe }; let current_repo = RepositoryLoader::new( - File::open(current_root).context(error::OpenRootSnafu { - path: ¤t_root, - })?, + &tokio::fs::read(current_root) + .await + .context(error::OpenRootSnafu { + path: ¤t_root, + })?, self.metadata_base_url.clone(), self.targets_base_url.clone(), ) .expiration_enforcement(expiration_enforcement) .load() + .await .context(error::RepoLoadSnafu)?; let mut editor = RepositoryEditor::new(new_root) + .await .context(error::EditorCreateSnafu { path: &new_root })?; editor @@ -129,11 +132,12 @@ impl TransferMetadataArgs { .context(error::DelegationStructureSnafu)?; } - let signed_repo = editor.sign(&keys).context(error::SignRepoSnafu)?; + let signed_repo = editor.sign(&keys).await.context(error::SignRepoSnafu)?; let metadata_dir = &self.outdir.join("metadata"); signed_repo .write(metadata_dir) + .await .context(error::WriteRepoSnafu { directory: metadata_dir, })?; diff --git a/tuftool/src/update.rs b/tuftool/src/update.rs index 84e5a9a3..c3c7b76a 100644 --- a/tuftool/src/update.rs +++ b/tuftool/src/update.rs @@ -9,7 +9,6 @@ use crate::source::parse_key_source; use chrono::{DateTime, Utc}; use clap::Parser; use snafu::{OptionExt, ResultExt}; -use std::fs::File; use std::num::{NonZeroU64, NonZeroUsize}; use std::path::{Path, PathBuf}; use tough::editor::signed::PathExists; @@ -108,7 +107,7 @@ WARNING: `--allow-expired-repo` was passed; this is unsafe and will not establis } impl UpdateArgs { - pub(crate) fn run(&self) -> Result<()> { + pub(crate) async fn run(&self) -> Result<()> { let expiration_enforcement = if self.allow_expired_repo { expired_repo_warning(&self.outdir); ExpirationEnforcement::Unsafe @@ -116,20 +115,25 @@ impl UpdateArgs { ExpirationEnforcement::Safe }; let repository = RepositoryLoader::new( - File::open(&self.root).context(error::OpenRootSnafu { path: &self.root })?, + &tokio::fs::read(&self.root) + .await + .context(error::OpenRootSnafu { path: &self.root })?, self.metadata_base_url.clone(), Url::parse(UNUSED_URL).context(error::UrlParseSnafu { url: UNUSED_URL })?, ) .expiration_enforcement(expiration_enforcement) .load() + .await .context(error::RepoLoadSnafu)?; self.update_metadata( RepositoryEditor::from_repo(&self.root, repository) + .await .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } - fn update_metadata(&self, mut editor: RepositoryEditor) -> Result<()> { + async fn update_metadata(&self, mut editor: RepositoryEditor) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -157,7 +161,7 @@ impl UpdateArgs { .context(error::InitializeThreadPoolSnafu)?; } - let new_targets = build_targets(targets_indir, self.follow)?; + let new_targets = build_targets(targets_indir, self.follow).await?; for (target_name, target) in new_targets { editor @@ -170,6 +174,7 @@ impl UpdateArgs { if self.role.is_some() && self.indir.is_some() { editor .sign_targets_editor(&keys) + .await .context(error::DelegationStructureSnafu)? .update_delegated_targets( self.role.as_ref().context(error::MissingSnafu { @@ -182,19 +187,21 @@ impl UpdateArgs { })? .as_str(), ) + .await .context(error::DelegateeNotFoundSnafu { role: self.role.as_ref().unwrap().clone(), })?; } // Sign the repo - let signed_repo = editor.sign(&keys).context(error::SignRepoSnafu)?; + let signed_repo = editor.sign(&keys).await.context(error::SignRepoSnafu)?; // Symlink any targets that were added if let Some(ref targets_indir) = self.targets_indir { let targets_outdir = &self.outdir.join("targets"); signed_repo .link_targets(targets_indir, targets_outdir, self.target_path_exists) + .await .context(error::LinkTargetsSnafu { indir: &targets_indir, outdir: targets_outdir, @@ -205,6 +212,7 @@ impl UpdateArgs { let metadata_dir = &self.outdir.join("metadata"); signed_repo .write(metadata_dir) + .await .context(error::WriteRepoSnafu { directory: metadata_dir, })?; diff --git a/tuftool/src/update_targets.rs b/tuftool/src/update_targets.rs index 70d78215..3221a619 100644 --- a/tuftool/src/update_targets.rs +++ b/tuftool/src/update_targets.rs @@ -67,15 +67,16 @@ pub(crate) struct UpdateTargetsArgs { } impl UpdateTargetsArgs { - pub(crate) fn run(&self, role: &str) -> Result<()> { - let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone())?; + pub(crate) async fn run(&self, role: &str) -> Result<()> { + let repository = load_metadata_repo(&self.root, self.metadata_base_url.clone()).await?; self.update_targets( TargetsEditor::from_repo(repository, role) .context(error::EditorFromRepoSnafu { path: &self.root })?, ) + .await } - fn update_targets(&self, mut editor: TargetsEditor) -> Result<()> { + async fn update_targets(&self, mut editor: TargetsEditor) -> Result<()> { let mut keys = Vec::new(); for source in &self.keys { let key_source = parse_key_source(source)?; @@ -95,7 +96,7 @@ impl UpdateTargetsArgs { .context(error::InitializeThreadPoolSnafu)?; } - let new_targets = build_targets(targets_indir, self.follow)?; + let new_targets = build_targets(targets_indir, self.follow).await?; for (target_name, target) in new_targets { editor @@ -105,13 +106,14 @@ impl UpdateTargetsArgs { }; // Sign the role - let signed_role = editor.sign(&keys).context(error::SignRepoSnafu)?; + let signed_role = editor.sign(&keys).await.context(error::SignRepoSnafu)?; // Copy any targets that were added if let Some(ref targets_indir) = self.targets_indir { let targets_outdir = &self.outdir.join("targets"); signed_role .copy_targets(targets_indir, targets_outdir, self.target_path_exists) + .await .context(error::LinkTargetsSnafu { indir: &targets_indir, outdir: targets_outdir, @@ -122,6 +124,7 @@ impl UpdateTargetsArgs { let metadata_dir = &self.outdir.join("metadata"); signed_role .write(metadata_dir, false) + .await .context(error::WriteRepoSnafu { directory: metadata_dir, })?; diff --git a/tuftool/tests/create_command.rs b/tuftool/tests/create_command.rs index 74093d26..6f5a8183 100644 --- a/tuftool/tests/create_command.rs +++ b/tuftool/tests/create_command.rs @@ -5,14 +5,13 @@ mod test_utils; use assert_cmd::Command; use chrono::{Duration, Utc}; -use std::fs::File; use tempfile::TempDir; use test_utils::dir_url; use tough::{RepositoryLoader, TargetName}; -#[test] +#[tokio::test] // Ensure we can read a repo created by the `tuftool` binary using the `tough` library -fn create_command() { +async fn create_command() { let timestamp_expiration = Utc::now().checked_add_signed(Duration::days(3)).unwrap(); let timestamp_version: u64 = 1234; let snapshot_expiration = Utc::now().checked_add_signed(Duration::days(21)).unwrap(); @@ -57,27 +56,28 @@ fn create_command() { // Load our newly created repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(repo_dir.path().join("metadata")), dir_url(repo_dir.path().join("targets")), ) .load() + .await .unwrap(); // Ensure we can read the targets let file1 = TargetName::new("file1.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file1).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file1).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); let file2 = TargetName::new("file2.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file2).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file2).await.unwrap().unwrap()).await, &b"This is an another example target file."[..] ); let file3 = TargetName::new("file3.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file3).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file3).await.unwrap().unwrap()).await, &b"This is role1's target file."[..] ); diff --git a/tuftool/tests/create_repository_integration.rs b/tuftool/tests/create_repository_integration.rs index 73d61aeb..64668fdf 100644 --- a/tuftool/tests/create_repository_integration.rs +++ b/tuftool/tests/create_repository_integration.rs @@ -5,7 +5,6 @@ mod test_utils; use assert_cmd::Command; use chrono::{Duration, Utc}; use std::env; -use std::fs::File; use tempfile::TempDir; use test_utils::dir_url; use tough::{RepositoryLoader, TargetName}; @@ -103,7 +102,7 @@ fn sign_root_json(key: &str, root_json: &str) { .success(); } -fn create_repository(root_key: &str, auto_generate: bool) { +async fn create_repository(root_key: &str, auto_generate: bool) { // create a root.json file to create TUF repository metadata let root_json_dir = TempDir::new().unwrap(); let root_json = root_json_dir.path().join("root.json"); @@ -157,27 +156,28 @@ fn create_repository(root_key: &str, auto_generate: bool) { // Load our newly created repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(repo_dir.path().join("metadata")), dir_url(repo_dir.path().join("targets")), ) .load() + .await .unwrap(); // Ensure we can read the targets let file1 = TargetName::new("file1.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file1).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file1).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); let file2 = TargetName::new("file2.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file2).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file2).await.unwrap().unwrap()).await, &b"This is an another example target file."[..] ); let file3 = TargetName::new("file3.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file3).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file3).await.unwrap().unwrap()).await, &b"This is role1's target file."[..] ); @@ -212,28 +212,28 @@ fn create_repository(root_key: &str, auto_generate: bool) { root_json_dir.close().unwrap(); } -#[test] +#[tokio::test] #[cfg_attr(not(feature = "integ"), ignore)] // Ensure we can use local rsa key to create and sign a repo created by the `tuftool` binary using the `tough` library -fn create_repository_local_key() { +async fn create_repository_local_key() { let root_key_dir = TempDir::new().unwrap(); let root_key_path = root_key_dir.path().join("local_key.pem"); let root_key = &format!("file://{}", root_key_path.to_str().unwrap()); - create_repository(root_key, true); + create_repository(root_key, true).await; } -#[test] +#[tokio::test] #[cfg_attr(not(feature = "integ"), ignore)] // Ensure we can use ssm key to create and sign a repo created by the `tuftool` binary using the `tough` library -fn create_repository_ssm_key() { +async fn create_repository_ssm_key() { let root_key = &format!("aws-ssm://{}/tough-integ/key-a", get_profile()); - create_repository(root_key, true); + create_repository(root_key, true).await; } -#[test] +#[tokio::test] #[cfg_attr(not(feature = "integ"), ignore)] // Ensure we can use kms key to create and sign a repo created by the `tuftool` binary using the `tough` library -fn create_repository_kms_key() { +async fn create_repository_kms_key() { let root_key = &format!("aws-kms://{}/alias/tough-integ/key-a", get_profile()); - create_repository(root_key, false); + create_repository(root_key, false).await; } diff --git a/tuftool/tests/delegation_commands.rs b/tuftool/tests/delegation_commands.rs index a2a2a728..8281be8b 100644 --- a/tuftool/tests/delegation_commands.rs +++ b/tuftool/tests/delegation_commands.rs @@ -5,7 +5,6 @@ mod test_utils; use assert_cmd::Command; use chrono::{Duration, Utc}; -use std::fs::File; use std::path::Path; use tempfile::TempDir; use test_utils::dir_url; @@ -54,10 +53,10 @@ fn create_repo>(repo_dir: P) { .success(); } -#[test] +#[tokio::test] // Ensure we can create a role, add the role to parent metadata, and sign repo // Structure targets -> A -> B -fn create_add_role_command() { +async fn create_add_role_command() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let targets_key = test_utils::test_data().join("targetskey"); @@ -144,11 +143,12 @@ fn create_add_role_command() { let updated_metadata_base_url = &dir_url(new_repo_dir.path().join("metadata")); let updated_targets_base_url = &dir_url(new_repo_dir.path().join("targets")); let repo = RepositoryLoader::new( - File::open(&root_json).unwrap(), + &tokio::fs::read(&root_json).await.unwrap(), updated_metadata_base_url.clone(), updated_targets_base_url.clone(), ) .load() + .await .unwrap(); // Make sure `A` is added as a role assert!(repo.delegated_role("A").is_some()); @@ -244,19 +244,20 @@ fn create_add_role_command() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Make sure `B` is added as a role assert!(repo.delegated_role("B").is_some()); } -#[test] +#[tokio::test] // Ensure we can update targets of delegated roles -fn update_target_command() { +async fn update_target_command() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let targets_key = test_utils::test_data().join("targetskey"); @@ -415,25 +416,26 @@ fn update_target_command() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Make sure we can read new target let file4 = TargetName::new("file4.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file4).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file4).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); } -#[test] +#[tokio::test] // Ensure we can add keys to A and B // Adds new key to A and signs with it -fn add_key_command() { +async fn add_key_command() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let targets_key = test_utils::test_data().join("targetskey"); @@ -669,11 +671,12 @@ fn add_key_command() { // Load the updated repo let _repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); } @@ -881,9 +884,9 @@ fn remove_key_command() { .failure(); } -#[test] +#[tokio::test] // Ensure we can remove a role -fn remove_role_command() { +async fn remove_role_command() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let targets_key = test_utils::test_data().join("targetskey"); @@ -966,11 +969,12 @@ fn remove_role_command() { let updated_metadata_base_url = dir_url(new_repo_dir.path().join("metadata")); let updated_targets_base_url = dir_url(new_repo_dir.path().join("targets")); let repo = RepositoryLoader::new( - File::open(&root_json).unwrap(), + &tokio::fs::read(&root_json).await.unwrap(), updated_metadata_base_url.clone(), updated_targets_base_url, ) .load() + .await .unwrap(); // Make sure `A` is added as a role assert!(repo.delegated_role("A").is_some()); @@ -1145,20 +1149,21 @@ fn remove_role_command() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Make sure `B` is removed assert!(repo.delegated_role("B").is_none()); } -#[test] +#[tokio::test] // Ensure we can remove a role -fn remove_role_recursive_command() { +async fn remove_role_recursive_command() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let targets_key = test_utils::test_data().join("targetskey"); @@ -1240,11 +1245,12 @@ fn remove_role_recursive_command() { // Load the updated repo let updated_metadata_base_url = &dir_url(new_repo_dir.path().join("metadata")); let repo = RepositoryLoader::new( - File::open(&root_json).unwrap(), + &tokio::fs::read(&root_json).await.unwrap(), updated_metadata_base_url.clone(), dir_url(new_repo_dir.path().join("targets")), ) .load() + .await .unwrap(); // Make sure `A` is added as a role assert!(repo.delegated_role("A").is_some()); @@ -1420,11 +1426,12 @@ fn remove_role_recursive_command() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Make sure `A` and `B` are removed @@ -1432,10 +1439,10 @@ fn remove_role_recursive_command() { assert!(repo.delegated_role("B").is_none()); } -#[test] +#[tokio::test] /// Ensure we that we percent encode path traversal characters when adding a role name such as /// `../../strange/role/../name` and that we don't write files in unexpected places. -fn dubious_role_name() { +async fn dubious_role_name() { let dubious_role_name = "../../strange/role/../name"; let dubious_name_encoded = "..%2F..%2Fstrange%2Frole%2F..%2Fname"; let funny_role_name = "../🍺/( ͡° ͜ʖ ͡°)"; @@ -1527,11 +1534,12 @@ fn dubious_role_name() { let updated_metadata_base_url = &dir_url(new_repo_dir.path().join("metadata")); let updated_targets_base_url = &dir_url(new_repo_dir.path().join("targets")); let repo = RepositoryLoader::new( - File::open(&root_json).unwrap(), + &tokio::fs::read(&root_json).await.unwrap(), updated_metadata_base_url.clone(), updated_targets_base_url.clone(), ) .load() + .await .unwrap(); // Make sure `A` is added as a role assert!(repo.delegated_role(dubious_role_name).is_some()); @@ -1639,11 +1647,12 @@ fn dubious_role_name() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Make sure `B` is added as a role diff --git a/tuftool/tests/root_command.rs b/tuftool/tests/root_command.rs index f7b3f2d7..288a08c5 100644 --- a/tuftool/tests/root_command.rs +++ b/tuftool/tests/root_command.rs @@ -284,8 +284,8 @@ fn create_invalid_root() { .failure(); } -#[test] -fn cross_sign_root() { +#[tokio::test] +async fn cross_sign_root() { let out_dir = TempDir::new().unwrap(); let old_root_json = test_utils::test_data() .join("cross-sign-root") @@ -299,6 +299,7 @@ fn cross_sign_root() { }; let old_key_id = old_key_source .as_sign() + .await .ok() .unwrap() .tuf_key() diff --git a/tuftool/tests/test_utils.rs b/tuftool/tests/test_utils.rs index be8c22dd..931d2bc9 100644 --- a/tuftool/tests/test_utils.rs +++ b/tuftool/tests/test_utils.rs @@ -3,8 +3,8 @@ use assert_cmd::Command; use chrono::{Duration, Utc}; -use std::io::Read; use std::path::{Path, PathBuf}; +use tough::IntoVec; use url::Url; /// Utilities for tests. Not every test module uses every function, so we suppress unused warnings. @@ -23,12 +23,14 @@ pub fn dir_url>(path: P) -> Url { Url::from_directory_path(path).unwrap() } -/// Returns a vector of bytes from any object with the Read trait +/// Returns a vector of bytes from any stream of byte results #[allow(unused)] -pub fn read_to_end(mut reader: R) -> Vec { - let mut v = Vec::new(); - reader.read_to_end(&mut v).unwrap(); - v +pub async fn read_to_end(mut stream: S) -> Vec +where + E: std::fmt::Debug, + S: IntoVec, +{ + stream.into_vec().await.unwrap() } /// Creates a repository with expired timestamp metadata. diff --git a/tuftool/tests/update_command.rs b/tuftool/tests/update_command.rs index b0fcc87c..026b35e8 100644 --- a/tuftool/tests/update_command.rs +++ b/tuftool/tests/update_command.rs @@ -6,7 +6,6 @@ mod test_utils; use assert_cmd::assert::Assert; use assert_cmd::Command; use chrono::{DateTime, Duration, Utc}; -use std::fs::File; use std::path::Path; use tempfile::TempDir; use test_utils::dir_url; @@ -55,9 +54,9 @@ fn create_repo>(repo_dir: P) { .success(); } -#[test] +#[tokio::test] // Ensure we can read a repo that has had its metadata updated by `tuftool create` -fn update_command_without_new_targets() { +async fn update_command_without_new_targets() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let repo_dir = TempDir::new().unwrap(); @@ -106,11 +105,12 @@ fn update_command_without_new_targets() { // Load the updated repo let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Ensure all the existing targets are accounted for @@ -125,10 +125,10 @@ fn update_command_without_new_targets() { assert_eq!(repo.timestamp().signed.expires, new_timestamp_expiration); } -#[test] +#[tokio::test] // Ensure we can read a repo that has had its metadata and targets updated // by `tuftool create` -fn update_command_with_new_targets() { +async fn update_command_with_new_targets() { let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let root_key = test_utils::test_data().join("snakeoil.pem"); let repo_dir = TempDir::new().unwrap(); @@ -180,11 +180,12 @@ fn update_command_with_new_targets() { // Load the updated repo. let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(update_out.path().join("metadata")), dir_url(update_out.path().join("targets")), ) .load() + .await .unwrap(); // Ensure all the targets (new and existing) are accounted for @@ -193,17 +194,17 @@ fn update_command_with_new_targets() { // Ensure we can read the newly added targets let file4 = TargetName::new("file4.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file4).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file4).await.unwrap().unwrap()).await, &b"This is an example target file."[..] ); let file5 = TargetName::new("file5.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file5).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file5).await.unwrap().unwrap()).await, &b"This is another example target file."[..] ); let file6 = TargetName::new("file6.txt").unwrap(); assert_eq!( - test_utils::read_to_end(repo.read_target(&file6).unwrap().unwrap()), + test_utils::read_to_end(repo.read_target(&file6).await.unwrap().unwrap()).await, &b"This is yet another example target file."[..] ); @@ -357,9 +358,9 @@ fn update_command_expired_repo_fail() { update_expected.0.failure(); } -#[test] +#[tokio::test] // Ensure we can update a repo that has its metadata expired but --allow-expired-repo flag is passed -fn update_command_expired_repo_allow() { +async fn update_command_expired_repo_allow() { let outdir = TempDir::new().unwrap(); let repo_dir = TempDir::new().unwrap(); // Create a expired repo using tuftool and the reference tuf implementation data @@ -370,11 +371,12 @@ fn update_command_expired_repo_allow() { // Load the updated repo let root_json = test_utils::test_data().join("simple-rsa").join("root.json"); let repo = RepositoryLoader::new( - File::open(root_json).unwrap(), + &tokio::fs::read(root_json).await.unwrap(), dir_url(outdir.path().join("metadata")), dir_url(outdir.path().join("targets")), ) .load() + .await .unwrap(); // Ensure all the existing targets are accounted for assert_eq!(repo.targets().signed.targets.len(), 3);