diff --git a/object_store/src/aws/checksum.rs b/object_store/src/aws/checksum.rs index a50bd2d18b9c..d15bbf08df69 100644 --- a/object_store/src/aws/checksum.rs +++ b/object_store/src/aws/checksum.rs @@ -16,7 +16,6 @@ // under the License. use crate::config::Parse; -use ring::digest::{self, digest as ring_digest}; use std::str::FromStr; #[allow(non_camel_case_types)] @@ -27,20 +26,6 @@ pub enum Checksum { SHA256, } -impl Checksum { - pub(super) fn digest(&self, bytes: &[u8]) -> Vec { - match self { - Self::SHA256 => ring_digest(&digest::SHA256, bytes).as_ref().to_owned(), - } - } - - pub(super) fn header_name(&self) -> &'static str { - match self { - Self::SHA256 => "x-amz-checksum-sha256", - } - } -} - impl std::fmt::Display for Checksum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self { diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 838bef8ac23b..c1789ed143e4 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -35,7 +35,8 @@ use crate::client::GetOptionsExt; use crate::multipart::PartId; use crate::path::DELIMITER; use crate::{ - ClientOptions, GetOptions, ListResult, MultipartId, Path, PutResult, Result, RetryConfig, + ClientOptions, GetOptions, ListResult, MultipartId, Path, PutPayload, PutResult, Result, + RetryConfig, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; @@ -51,11 +52,14 @@ use reqwest::{ header::{CONTENT_LENGTH, CONTENT_TYPE}, Client as ReqwestClient, Method, RequestBuilder, Response, }; +use ring::digest; +use ring::digest::Context; use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; use std::sync::Arc; const VERSION_HEADER: &str = "x-amz-version-id"; +const SHA256_CHECKSUM: &str = "x-amz-checksum-sha256"; /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] @@ -266,7 +270,8 @@ pub(crate) struct Request<'a> { path: &'a Path, config: &'a S3Config, builder: RequestBuilder, - payload_sha256: Option>, + payload_sha256: Option, + payload: Option, use_session_creds: bool, idempotent: bool, } @@ -286,7 +291,7 @@ impl<'a> Request<'a> { Self { builder, ..self } } - pub fn set_idempotent(mut self, idempotent: bool) -> Self { + pub fn idempotent(mut self, idempotent: bool) -> Self { self.idempotent = idempotent; self } @@ -301,10 +306,15 @@ impl<'a> Request<'a> { }, }; + let sha = self.payload_sha256.as_ref().map(|x| x.as_ref()); + let path = self.path.as_ref(); self.builder - .with_aws_sigv4(credential.authorizer(), self.payload_sha256.as_deref()) - .send_retry_with_idempotency(&self.config.retry_config, self.idempotent) + .with_aws_sigv4(credential.authorizer(), sha) + .retryable(&self.config.retry_config) + .idempotent(self.idempotent) + .payload(self.payload) + .send() .await .context(RetrySnafu { path }) } @@ -333,7 +343,7 @@ impl S3Client { pub fn put_request<'a>( &'a self, path: &'a Path, - bytes: Bytes, + payload: PutPayload, with_encryption_headers: bool, ) -> Request<'a> { let url = self.config.path_url(path); @@ -341,20 +351,17 @@ impl S3Client { if with_encryption_headers { builder = builder.headers(self.config.encryption_headers.clone().into()); } - let mut payload_sha256 = None; - if let Some(checksum) = self.config.checksum { - let digest = checksum.digest(&bytes); - builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest)); - if checksum == Checksum::SHA256 { - payload_sha256 = Some(digest); - } - } + let mut sha256 = Context::new(&digest::SHA256); + payload.iter().for_each(|x| sha256.update(x)); + let payload_sha256 = sha256.finish(); - builder = match bytes.is_empty() { - true => builder.header(CONTENT_LENGTH, 0), // Handle empty uploads (#4514) - false => builder.body(bytes), - }; + if let Some(Checksum::SHA256) = self.config.checksum { + builder = builder.header( + "x-amz-checksum-sha256", + BASE64_STANDARD.encode(payload_sha256), + ) + } if let Some(value) = self.config.client_options.get_content_type(path) { builder = builder.header(CONTENT_TYPE, value); @@ -362,8 +369,9 @@ impl S3Client { Request { path, - builder, - payload_sha256, + builder: builder.header(CONTENT_LENGTH, payload.content_length()), + payload: Some(payload), + payload_sha256: Some(payload_sha256), config: &self.config, use_session_creds: true, idempotent: false, @@ -446,16 +454,8 @@ impl S3Client { let mut builder = self.client.request(Method::POST, url); - // Compute checksum - S3 *requires* this for DeleteObjects requests, so we default to - // their algorithm if the user hasn't specified one. - let checksum = self.config.checksum.unwrap_or(Checksum::SHA256); - let digest = checksum.digest(&body); - builder = builder.header(checksum.header_name(), BASE64_STANDARD.encode(&digest)); - let payload_sha256 = if checksum == Checksum::SHA256 { - Some(digest) - } else { - None - }; + let digest = digest::digest(&digest::SHA256, &body); + builder = builder.header(SHA256_CHECKSUM, BASE64_STANDARD.encode(digest)); // S3 *requires* DeleteObjects to include a Content-MD5 header: // https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html @@ -468,8 +468,8 @@ impl S3Client { let response = builder .header(CONTENT_TYPE, "application/xml") .body(body) - .with_aws_sigv4(credential.authorizer(), payload_sha256.as_deref()) - .send_retry_with_idempotency(&self.config.retry_config, false) + .with_aws_sigv4(credential.authorizer(), Some(digest.as_ref())) + .send_retry(&self.config.retry_config) .await .context(DeleteObjectsRequestSnafu {})? .bytes() @@ -515,6 +515,7 @@ impl S3Client { builder, path: from, config: &self.config, + payload: None, payload_sha256: None, use_session_creds: false, idempotent: false, @@ -530,7 +531,9 @@ impl S3Client { .request(Method::POST, url) .headers(self.config.encryption_headers.clone().into()) .with_aws_sigv4(credential.authorizer(), None) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(CreateMultipartRequestSnafu)? .bytes() @@ -548,14 +551,14 @@ impl S3Client { path: &Path, upload_id: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result { let part = (part_idx + 1).to_string(); let response = self .put_request(path, data, false) .query(&[("partNumber", &part), ("uploadId", upload_id)]) - .set_idempotent(true) + .idempotent(true) .send() .await?; @@ -573,7 +576,7 @@ impl S3Client { // If no parts were uploaded, upload an empty part // otherwise the completion request will fail let part = self - .put_part(location, &upload_id.to_string(), 0, Bytes::new()) + .put_part(location, &upload_id.to_string(), 0, PutPayload::default()) .await?; vec![part] } else { @@ -591,7 +594,9 @@ impl S3Client { .query(&[("uploadId", upload_id)]) .body(body) .with_aws_sigv4(credential.authorizer(), None) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(CompleteMultipartRequestSnafu)?; diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index a7d1a9772aa1..08831fd51234 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -517,7 +517,9 @@ async fn instance_creds( let token_result = client .request(Method::PUT, token_url) .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL - .send_retry_with_idempotency(retry_config, true) + .retryable(retry_config) + .idempotent(true) + .send() .await; let token = match token_result { @@ -607,7 +609,9 @@ async fn web_identity( ("Version", "2011-06-15"), ("WebIdentityToken", &token), ]) - .send_retry_with_idempotency(retry_config, true) + .retryable(retry_config) + .idempotent(true) + .send() .await? .bytes() .await?; diff --git a/object_store/src/aws/dynamo.rs b/object_store/src/aws/dynamo.rs index 2e60bbad2226..2390187e7f72 100644 --- a/object_store/src/aws/dynamo.rs +++ b/object_store/src/aws/dynamo.rs @@ -186,11 +186,7 @@ impl DynamoCommit { to: &Path, ) -> Result<()> { self.conditional_op(client, to, None, || async { - client - .copy_request(from, to) - .set_idempotent(false) - .send() - .await?; + client.copy_request(from, to).send().await?; Ok(()) }) .await diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 16af4d3b4107..9e741c9142dd 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -29,7 +29,6 @@ //! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ use async_trait::async_trait; -use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use reqwest::header::{HeaderName, IF_MATCH, IF_NONE_MATCH}; @@ -46,7 +45,7 @@ use crate::signer::Signer; use crate::util::STRICT_ENCODE_SET; use crate::{ Error, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, - ObjectStore, Path, PutMode, PutOptions, PutResult, Result, UploadPart, + ObjectStore, Path, PutMode, PutOptions, PutPayload, PutResult, Result, UploadPart, }; static TAGS_HEADER: HeaderName = HeaderName::from_static("x-amz-tagging"); @@ -151,15 +150,20 @@ impl Signer for AmazonS3 { #[async_trait] impl ObjectStore for AmazonS3 { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { - let mut request = self.client.put_request(location, bytes, true); + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let mut request = self.client.put_request(location, payload, true); let tags = opts.tags.encoded(); if !tags.is_empty() && !self.client.config.disable_tagging { request = request.header(&TAGS_HEADER, tags); } match (opts.mode, &self.client.config.conditional_put) { - (PutMode::Overwrite, _) => request.set_idempotent(true).do_put().await, + (PutMode::Overwrite, _) => request.idempotent(true).do_put().await, (PutMode::Create | PutMode::Update(_), None) => Err(Error::NotImplemented), (PutMode::Create, Some(S3ConditionalPut::ETagMatch)) => { match request.header(&IF_NONE_MATCH, "*").do_put().await { @@ -270,7 +274,7 @@ impl ObjectStore for AmazonS3 { async fn copy(&self, from: &Path, to: &Path) -> Result<()> { self.client .copy_request(from, to) - .set_idempotent(true) + .idempotent(true) .send() .await?; Ok(()) @@ -320,7 +324,7 @@ struct UploadState { #[async_trait] impl MultipartUpload for S3MultiPartUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, data: PutPayload) -> UploadPart { let idx = self.part_idx; self.part_idx += 1; let state = Arc::clone(&self.state); @@ -362,7 +366,7 @@ impl MultipartStore for AmazonS3 { path: &Path, id: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result { self.client.put_part(path, id, part_idx, data).await } @@ -385,7 +389,6 @@ impl MultipartStore for AmazonS3 { mod tests { use super::*; use crate::{client::get::GetClient, tests::*}; - use bytes::Bytes; use hyper::HeaderMap; const NON_EXISTENT_NAME: &str = "nonexistentname"; @@ -474,7 +477,7 @@ mod tests { let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); - let data = Bytes::from("arbitrary data"); + let data = PutPayload::from("arbitrary data"); let err = integration.put(&location, data).await.unwrap_err(); assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); @@ -531,7 +534,7 @@ mod tests { async fn s3_encryption(store: &AmazonS3) { crate::test_util::maybe_skip_integration!(); - let data = Bytes::from(vec![3u8; 1024]); + let data = PutPayload::from(vec![3u8; 1024]); let encryption_headers: HeaderMap = store.client.config.encryption_headers.clone().into(); let expected_encryption = diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 0e6af50fbf94..d5972d0a8c16 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -27,8 +27,8 @@ use crate::multipart::PartId; use crate::path::DELIMITER; use crate::util::{deserialize_rfc1123, GetRange}; use crate::{ - ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, PutOptions, PutResult, - Result, RetryConfig, + ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, PutOptions, PutPayload, + PutResult, Result, RetryConfig, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; @@ -171,6 +171,7 @@ impl AzureConfig { struct PutRequest<'a> { path: &'a Path, config: &'a AzureConfig, + payload: PutPayload, builder: RequestBuilder, idempotent: bool, } @@ -195,8 +196,12 @@ impl<'a> PutRequest<'a> { let credential = self.config.get_credential().await?; let response = self .builder + .header(CONTENT_LENGTH, self.payload.content_length()) .with_azure_authorization(&credential, &self.config.account) - .send_retry_with_idempotency(&self.config.retry_config, self.idempotent) + .retryable(&self.config.retry_config) + .idempotent(true) + .payload(Some(self.payload)) + .send() .await .context(PutRequestSnafu { path: self.path.as_ref(), @@ -228,7 +233,7 @@ impl AzureClient { self.config.get_credential().await } - fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> PutRequest<'a> { + fn put_request<'a>(&'a self, path: &'a Path, payload: PutPayload) -> PutRequest<'a> { let url = self.config.path_url(path); let mut builder = self.client.request(Method::PUT, url); @@ -237,21 +242,23 @@ impl AzureClient { builder = builder.header(CONTENT_TYPE, value); } - builder = builder - .header(CONTENT_LENGTH, HeaderValue::from(bytes.len())) - .body(bytes); - PutRequest { path, builder, + payload, config: &self.config, idempotent: false, } } /// Make an Azure PUT request - pub async fn put_blob(&self, path: &Path, bytes: Bytes, opts: PutOptions) -> Result { - let builder = self.put_request(path, bytes); + pub async fn put_blob( + &self, + path: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let builder = self.put_request(path, payload); let builder = match &opts.mode { PutMode::Overwrite => builder.set_idempotent(true), @@ -272,11 +279,16 @@ impl AzureClient { } /// PUT a block - pub async fn put_block(&self, path: &Path, part_idx: usize, data: Bytes) -> Result { + pub async fn put_block( + &self, + path: &Path, + part_idx: usize, + payload: PutPayload, + ) -> Result { let content_id = format!("{part_idx:20}"); let block_id = BASE64_STANDARD.encode(&content_id); - self.put_request(path, data) + self.put_request(path, payload) .query(&[("comp", "block"), ("blockid", &block_id)]) .set_idempotent(true) .send() @@ -349,7 +361,9 @@ impl AzureClient { builder .with_azure_authorization(&credential, &self.config.account) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(overwrite) + .send() .await .map_err(|err| err.error(STORE, from.to_string()))?; @@ -382,7 +396,9 @@ impl AzureClient { .body(body) .query(&[("restype", "service"), ("comp", "userdelegationkey")]) .with_azure_authorization(&credential, &self.config.account) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(DelegationKeyRequestSnafu)? .bytes() diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 36845bd1d646..c8212a9290f5 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -615,7 +615,9 @@ impl TokenProvider for ClientSecretOAuthProvider { ("scope", AZURE_STORAGE_SCOPE), ("grant_type", "client_credentials"), ]) - .send_retry_with_idempotency(retry, true) + .retryable(retry) + .idempotent(true) + .send() .await .context(TokenRequestSnafu)? .json() @@ -797,7 +799,9 @@ impl TokenProvider for WorkloadIdentityOAuthProvider { ("scope", AZURE_STORAGE_SCOPE), ("grant_type", "client_credentials"), ]) - .send_retry_with_idempotency(retry, true) + .retryable(retry) + .idempotent(true) + .send() .await .context(TokenRequestSnafu)? .json() diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 5d3a405ccc93..8dc52422b7de 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -27,10 +27,9 @@ use crate::{ path::Path, signer::Signer, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, ObjectMeta, ObjectStore, - PutOptions, PutResult, Result, UploadPart, + PutOptions, PutPayload, PutResult, Result, UploadPart, }; use async_trait::async_trait; -use bytes::Bytes; use futures::stream::BoxStream; use reqwest::Method; use std::fmt::Debug; @@ -87,8 +86,13 @@ impl std::fmt::Display for MicrosoftAzure { #[async_trait] impl ObjectStore for MicrosoftAzure { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { - self.client.put_blob(location, bytes, opts).await + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.client.put_blob(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -203,7 +207,7 @@ struct UploadState { #[async_trait] impl MultipartUpload for AzureMultiPartUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, data: PutPayload) -> UploadPart { let idx = self.part_idx; self.part_idx += 1; let state = Arc::clone(&self.state); @@ -240,7 +244,7 @@ impl MultipartStore for MicrosoftAzure { path: &Path, _: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result { self.client.put_block(path, part_idx, data).await } @@ -265,6 +269,7 @@ impl MultipartStore for MicrosoftAzure { mod tests { use super::*; use crate::tests::*; + use bytes::Bytes; #[tokio::test] async fn azure_blob_test() { @@ -309,7 +314,7 @@ mod tests { let data = Bytes::from("hello world"); let path = Path::from("file.txt"); - integration.put(&path, data.clone()).await.unwrap(); + integration.put(&path, data.clone().into()).await.unwrap(); let signed = integration .signed_url(Method::GET, &path, Duration::from_secs(60)) diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index de6d4eb1bb9c..8750b74e9dc6 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -391,7 +391,7 @@ mod tests { const BYTES: usize = 4096; let data: Bytes = b"12345678".iter().cycle().copied().take(BYTES).collect(); - store.put(&existent, data.clone()).await.unwrap(); + store.put(&existent, data.clone().into()).await.unwrap(); let meta = store.head(&existent).await.unwrap(); diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index 6db7f4b35e24..9abe49dbfce9 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -27,11 +27,11 @@ use futures::stream::BoxStream; use futures::StreamExt; use crate::path::Path; -use crate::Result; use crate::{ GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutOptions, PutResult, }; +use crate::{PutPayload, Result}; /// Wraps a [`ObjectStore`] and makes its get response return chunks /// in a controllable manner. @@ -62,8 +62,13 @@ impl Display for ChunkedStore { #[async_trait] impl ObjectStore for ChunkedStore { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { - self.inner.put_opts(location, bytes, opts).await + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.inner.put_opts(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -176,10 +181,7 @@ mod tests { async fn test_chunked_basic() { let location = Path::parse("test").unwrap(); let store: Arc = Arc::new(InMemory::new()); - store - .put(&location, Bytes::from(vec![0; 1001])) - .await - .unwrap(); + store.put(&location, vec![0; 1001].into()).await.unwrap(); for chunk_size in [10, 20, 31] { let store = ChunkedStore::new(Arc::clone(&store), chunk_size); diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index f3fa7153e930..422b7070070d 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -18,10 +18,10 @@ //! A shared HTTP client implementation incorporating retries use crate::client::backoff::{Backoff, BackoffConfig}; +use crate::PutPayload; use futures::future::BoxFuture; -use futures::FutureExt; use reqwest::header::LOCATION; -use reqwest::{Response, StatusCode}; +use reqwest::{Client, Request, Response, StatusCode}; use snafu::Error as SnafuError; use snafu::Snafu; use std::time::{Duration, Instant}; @@ -166,26 +166,54 @@ impl Default for RetryConfig { } } -fn send_retry_impl( - builder: reqwest::RequestBuilder, - config: &RetryConfig, - is_idempotent: Option, -) -> BoxFuture<'static, Result> { - let mut backoff = Backoff::new(&config.backoff); - let max_retries = config.max_retries; - let retry_timeout = config.retry_timeout; +pub struct RetryableRequest { + client: Client, + request: Request, - let (client, req) = builder.build_split(); - let req = req.expect("request must be valid"); - let is_idempotent = is_idempotent.unwrap_or(req.method().is_safe()); + max_retries: usize, + retry_timeout: Duration, + backoff: Backoff, - async move { + idempotent: Option, + payload: Option, +} + +impl RetryableRequest { + /// Set whether this request is idempotent + pub fn idempotent(self, idempotent: bool) -> Self { + Self { + idempotent: Some(idempotent), + ..self + } + } + + /// Provide a [`PutPayload`] + pub fn payload(self, payload: Option) -> Self { + Self { payload, ..self } + } + + pub async fn send(self) -> Result { + let max_retries = self.max_retries; + let retry_timeout = self.retry_timeout; let mut retries = 0; let now = Instant::now(); + let mut backoff = self.backoff; + let is_idempotent = self + .idempotent + .unwrap_or_else(|| self.request.method().is_safe()); + loop { - let s = req.try_clone().expect("request body must be cloneable"); - match client.execute(s).await { + let mut s = self + .request + .try_clone() + .expect("request body must be cloneable"); + + if let Some(x) = &self.payload { + *s.body_mut() = Some(x.body()); + } + + match self.client.execute(s).await { Ok(r) => match r.error_for_status_ref() { Ok(_) if r.status().is_success() => return Ok(r), Ok(r) if r.status() == StatusCode::NOT_MODIFIED => { @@ -195,47 +223,44 @@ fn send_retry_impl( }) } Ok(r) => { - let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); + let is_bare_redirect = + r.status().is_redirection() && !r.headers().contains_key(LOCATION); return match is_bare_redirect { true => Err(Error::BareRedirect), // Not actually sure if this is reachable, but here for completeness false => Err(Error::Client { body: None, status: r.status(), - }) - } + }), + }; } Err(e) => { let status = r.status(); if retries == max_retries || now.elapsed() > retry_timeout - || !status.is_server_error() { - + || !status.is_server_error() + { return Err(match status.is_client_error() { true => match r.text().await { - Ok(body) => { - Error::Client { - body: Some(body).filter(|b| !b.is_empty()), - status, - } - } - Err(e) => { - Error::Reqwest { - retries, - max_retries, - elapsed: now.elapsed(), - retry_timeout, - source: e, - } - } - } + Ok(body) => Error::Client { + body: Some(body).filter(|b| !b.is_empty()), + status, + }, + Err(e) => Error::Reqwest { + retries, + max_retries, + elapsed: now.elapsed(), + retry_timeout, + source: e, + }, + }, false => Error::Reqwest { retries, max_retries, elapsed: now.elapsed(), retry_timeout, source: e, - } + }, }); } @@ -251,13 +276,13 @@ fn send_retry_impl( tokio::time::sleep(sleep).await; } }, - Err(e) => - { + Err(e) => { let mut do_retry = false; if e.is_connect() || e.is_body() || (e.is_request() && !e.is_timeout()) - || (is_idempotent && e.is_timeout()) { + || (is_idempotent && e.is_timeout()) + { do_retry = true } else { let mut source = e.source(); @@ -267,7 +292,7 @@ fn send_retry_impl( || e.is_incomplete_message() || e.is_body_write_aborted() || (is_idempotent && e.is_timeout()); - break + break; } if let Some(e) = e.downcast_ref::() { if e.kind() == std::io::ErrorKind::TimedOut { @@ -276,9 +301,9 @@ fn send_retry_impl( do_retry = matches!( e.kind(), std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::ConnectionAborted - | std::io::ErrorKind::BrokenPipe - | std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::BrokenPipe + | std::io::ErrorKind::UnexpectedEof ); } break; @@ -287,17 +312,14 @@ fn send_retry_impl( } } - if retries == max_retries - || now.elapsed() > retry_timeout - || !do_retry { - + if retries == max_retries || now.elapsed() > retry_timeout || !do_retry { return Err(Error::Reqwest { retries, max_retries, elapsed: now.elapsed(), retry_timeout, source: e, - }) + }); } let sleep = backoff.next(); retries += 1; @@ -313,39 +335,39 @@ fn send_retry_impl( } } } - .boxed() } pub trait RetryExt { + /// Return a [`RetryableRequest`] + fn retryable(self, config: &RetryConfig) -> RetryableRequest; + /// Dispatch a request with the given retry configuration /// /// # Panic /// /// This will panic if the request body is a stream fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; - - /// Dispatch a request with the given retry configuration and idempotency - /// - /// # Panic - /// - /// This will panic if the request body is a stream - fn send_retry_with_idempotency( - self, - config: &RetryConfig, - is_idempotent: bool, - ) -> BoxFuture<'static, Result>; } impl RetryExt for reqwest::RequestBuilder { - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { - send_retry_impl(self, config, None) + fn retryable(self, config: &RetryConfig) -> RetryableRequest { + let (client, request) = self.build_split(); + let request = request.expect("request must be valid"); + + RetryableRequest { + client, + request, + max_retries: config.max_retries, + retry_timeout: config.retry_timeout, + backoff: Backoff::new(&config.backoff), + idempotent: None, + payload: None, + } } - fn send_retry_with_idempotency( - self, - config: &RetryConfig, - is_idempotent: bool, - ) -> BoxFuture<'static, Result> { - send_retry_impl(self, config, Some(is_idempotent)) + + fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { + let request = self.retryable(config); + Box::pin(async move { request.send().await }) } } diff --git a/object_store/src/gcp/client.rs b/object_store/src/gcp/client.rs index 17404f9d5acd..c74d7abce4f2 100644 --- a/object_store/src/gcp/client.rs +++ b/object_store/src/gcp/client.rs @@ -29,13 +29,14 @@ use crate::multipart::PartId; use crate::path::{Path, DELIMITER}; use crate::util::hex_encode; use crate::{ - ClientOptions, GetOptions, ListResult, MultipartId, PutMode, PutOptions, PutResult, Result, - RetryConfig, + ClientOptions, GetOptions, ListResult, MultipartId, PutMode, PutOptions, PutPayload, PutResult, + Result, RetryConfig, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; -use bytes::{Buf, Bytes}; +use bytes::Buf; +use hyper::header::CONTENT_LENGTH; use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; use reqwest::header::HeaderName; use reqwest::{header, Client, Method, RequestBuilder, Response, StatusCode}; @@ -172,6 +173,7 @@ impl GoogleCloudStorageConfig { pub struct PutRequest<'a> { path: &'a Path, config: &'a GoogleCloudStorageConfig, + payload: PutPayload, builder: RequestBuilder, idempotent: bool, } @@ -197,7 +199,11 @@ impl<'a> PutRequest<'a> { let response = self .builder .bearer_auth(&credential.bearer) - .send_retry_with_idempotency(&self.config.retry_config, self.idempotent) + .header(CONTENT_LENGTH, self.payload.content_length()) + .retryable(&self.config.retry_config) + .idempotent(self.idempotent) + .payload(Some(self.payload)) + .send() .await .context(PutRequestSnafu { path: self.path.as_ref(), @@ -287,7 +293,9 @@ impl GoogleCloudStorageClient { .post(&url) .bearer_auth(&credential.bearer) .json(&body) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(SignBlobRequestSnafu)?; @@ -315,7 +323,7 @@ impl GoogleCloudStorageClient { /// Perform a put request /// /// Returns the new ETag - pub fn put_request<'a>(&'a self, path: &'a Path, payload: Bytes) -> PutRequest<'a> { + pub fn put_request<'a>(&'a self, path: &'a Path, payload: PutPayload) -> PutRequest<'a> { let url = self.object_url(path); let content_type = self @@ -327,20 +335,24 @@ impl GoogleCloudStorageClient { let builder = self .client .request(Method::PUT, url) - .header(header::CONTENT_TYPE, content_type) - .header(header::CONTENT_LENGTH, payload.len()) - .body(payload); + .header(header::CONTENT_TYPE, content_type); PutRequest { path, builder, + payload, config: &self.config, idempotent: false, } } - pub async fn put(&self, path: &Path, data: Bytes, opts: PutOptions) -> Result { - let builder = self.put_request(path, data); + pub async fn put( + &self, + path: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + let builder = self.put_request(path, payload); let builder = match &opts.mode { PutMode::Overwrite => builder.set_idempotent(true), @@ -367,7 +379,7 @@ impl GoogleCloudStorageClient { path: &Path, upload_id: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result { let query = &[ ("partNumber", &format!("{}", part_idx + 1)), @@ -403,7 +415,9 @@ impl GoogleCloudStorageClient { .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, "0") .query(&[("uploads", "")]) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(PutRequestSnafu { path: path.as_ref(), @@ -472,7 +486,9 @@ impl GoogleCloudStorageClient { .bearer_auth(&credential.bearer) .query(&[("uploadId", upload_id)]) .body(data) - .send_retry_with_idempotency(&self.config.retry_config, true) + .retryable(&self.config.retry_config) + .idempotent(true) + .send() .await .context(CompleteMultipartRequestSnafu)?; @@ -530,8 +546,10 @@ impl GoogleCloudStorageClient { .bearer_auth(&credential.bearer) // Needed if reqwest is compiled with native-tls instead of rustls-tls // See https://github.com/apache/arrow-rs/pull/3921 - .header(header::CONTENT_LENGTH, 0) - .send_retry_with_idempotency(&self.config.retry_config, !if_not_exists) + .header(CONTENT_LENGTH, 0) + .retryable(&self.config.retry_config) + .idempotent(!if_not_exists) + .send() .await .map_err(|err| match err.status() { Some(StatusCode::PRECONDITION_FAILED) => crate::Error::AlreadyExists { diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs index 158716ce4c18..ed13dd9730e7 100644 --- a/object_store/src/gcp/credential.rs +++ b/object_store/src/gcp/credential.rs @@ -623,7 +623,9 @@ impl TokenProvider for AuthorizedUserCredentials { ("client_secret", &self.client_secret), ("refresh_token", &self.refresh_token), ]) - .send_retry_with_idempotency(retry, true) + .retryable(retry) + .idempotent(true) + .send() .await .context(TokenRequestSnafu)? .json::() diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 96afa45f2b61..149da76f559a 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -42,10 +42,9 @@ use crate::gcp::credential::GCSAuthorizer; use crate::signer::Signer; use crate::{ multipart::PartId, path::Path, GetOptions, GetResult, ListResult, MultipartId, MultipartUpload, - ObjectMeta, ObjectStore, PutOptions, PutResult, Result, UploadPart, + ObjectMeta, ObjectStore, PutOptions, PutPayload, PutResult, Result, UploadPart, }; use async_trait::async_trait; -use bytes::Bytes; use client::GoogleCloudStorageClient; use futures::stream::BoxStream; use hyper::Method; @@ -115,14 +114,14 @@ struct UploadState { #[async_trait] impl MultipartUpload for GCSMultipartUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, payload: PutPayload) -> UploadPart { let idx = self.part_idx; self.part_idx += 1; let state = Arc::clone(&self.state); Box::pin(async move { let part = state .client - .put_part(&state.path, &state.multipart_id, idx, data) + .put_part(&state.path, &state.multipart_id, idx, payload) .await?; state.parts.put(idx, part); Ok(()) @@ -148,8 +147,13 @@ impl MultipartUpload for GCSMultipartUpload { #[async_trait] impl ObjectStore for GoogleCloudStorage { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { - self.client.put(location, bytes, opts).await + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.client.put(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -210,9 +214,9 @@ impl MultipartStore for GoogleCloudStorage { path: &Path, id: &MultipartId, part_idx: usize, - data: Bytes, + payload: PutPayload, ) -> Result { - self.client.put_part(path, id, part_idx, data).await + self.client.put_part(path, id, part_idx, payload).await } async fn complete_multipart( @@ -260,7 +264,6 @@ impl Signer for GoogleCloudStorage { #[cfg(test)] mod test { - use bytes::Bytes; use credential::DEFAULT_GCS_BASE_URL; use crate::tests::*; @@ -391,7 +394,7 @@ mod test { let integration = config.with_bucket_name(NON_EXISTENT_NAME).build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); - let data = Bytes::from("arbitrary data"); + let data = PutPayload::from("arbitrary data"); let err = integration .put(&location, data) diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index fdc8751c1ca1..39f68ece65a3 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -21,10 +21,11 @@ use crate::client::retry::{self, RetryConfig, RetryExt}; use crate::client::GetOptionsExt; use crate::path::{Path, DELIMITER}; use crate::util::deserialize_rfc1123; -use crate::{ClientOptions, GetOptions, ObjectMeta, Result}; +use crate::{ClientOptions, GetOptions, ObjectMeta, PutPayload, Result}; use async_trait::async_trait; -use bytes::{Buf, Bytes}; +use bytes::Buf; use chrono::{DateTime, Utc}; +use hyper::header::CONTENT_LENGTH; use percent_encoding::percent_decode_str; use reqwest::header::CONTENT_TYPE; use reqwest::{Method, Response, StatusCode}; @@ -156,16 +157,24 @@ impl Client { Ok(()) } - pub async fn put(&self, location: &Path, bytes: Bytes) -> Result { + pub async fn put(&self, location: &Path, payload: PutPayload) -> Result { let mut retry = false; loop { let url = self.path_url(location); - let mut builder = self.client.put(url).body(bytes.clone()); + let mut builder = self.client.put(url); if let Some(value) = self.client_options.get_content_type(location) { builder = builder.header(CONTENT_TYPE, value); } - match builder.send_retry(&self.retry_config).await { + let resp = builder + .header(CONTENT_LENGTH, payload.content_length()) + .retryable(&self.retry_config) + .idempotent(true) + .payload(Some(payload.clone())) + .send() + .await; + + match resp { Ok(response) => return Ok(response), Err(source) => match source.status() { // Some implementations return 404 instead of 409 @@ -189,7 +198,9 @@ impl Client { .client .request(method, url) .header("Depth", depth) - .send_retry_with_idempotency(&self.retry_config, true) + .retryable(&self.retry_config) + .idempotent(true) + .send() .await; let response = match result { diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 626337df27f9..a838a0f479d9 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -32,7 +32,6 @@ //! [WebDAV]: https://en.wikipedia.org/wiki/WebDAV use async_trait::async_trait; -use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; @@ -45,7 +44,7 @@ use crate::http::client::Client; use crate::path::Path; use crate::{ ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, PutMode, PutOptions, PutResult, Result, RetryConfig, + ObjectStore, PutMode, PutOptions, PutPayload, PutResult, Result, RetryConfig, }; mod client; @@ -95,13 +94,18 @@ impl std::fmt::Display for HttpStore { #[async_trait] impl ObjectStore for HttpStore { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { if opts.mode != PutMode::Overwrite { // TODO: Add support for If header - https://datatracker.ietf.org/doc/html/rfc2518#section-9.4 return Err(crate::Error::NotImplemented); } - let response = self.client.put(location, bytes).await?; + let response = self.client.put(location, payload).await?; let e_tag = match get_etag(response.headers()) { Ok(e_tag) => Some(e_tag), Err(crate::client::header::Error::MissingEtag) => None, diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 97604a7dce68..692160a03596 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -251,9 +251,8 @@ //! //! ``` //! # use object_store::local::LocalFileSystem; -//! # use object_store::ObjectStore; +//! # use object_store::{ObjectStore, PutPayload}; //! # use std::sync::Arc; -//! # use bytes::Bytes; //! # use object_store::path::Path; //! # fn get_object_store() -> Arc { //! # Arc::new(LocalFileSystem::new()) @@ -262,8 +261,8 @@ //! # //! let object_store: Arc = get_object_store(); //! let path = Path::from("data/file1"); -//! let bytes = Bytes::from_static(b"hello"); -//! object_store.put(&path, bytes).await.unwrap(); +//! let payload = PutPayload::from_static(b"hello"); +//! object_store.put(&path, payload).await.unwrap(); //! # } //! ``` //! @@ -427,7 +426,7 @@ //! let new = do_update(r.bytes().await.unwrap()); //! //! // Attempt to commit transaction -//! match store.put_opts(&path, new, PutMode::Update(version).into()).await { +//! match store.put_opts(&path, new.into(), PutMode::Update(version).into()).await { //! Ok(_) => break, // Successfully committed //! Err(Error::Precondition { .. }) => continue, // Object has changed, try again //! Err(e) => panic!("{e}") @@ -498,17 +497,18 @@ pub use tags::TagSet; pub mod multipart; mod parse; +mod payload; mod upload; mod util; pub use parse::{parse_url, parse_url_opts}; +pub use payload::*; pub use upload::*; -pub use util::GetRange; +pub use util::{coalesce_ranges, collect_bytes, GetRange, OBJECT_STORE_COALESCE_DEFAULT}; use crate::path::Path; #[cfg(not(target_arch = "wasm32"))] use crate::util::maybe_spawn_blocking; -pub use crate::util::{coalesce_ranges, collect_bytes, OBJECT_STORE_COALESCE_DEFAULT}; use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; @@ -532,14 +532,20 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// Save the provided bytes to the specified location /// /// The operation is guaranteed to be atomic, it will either successfully - /// write the entirety of `bytes` to `location`, or fail. No clients + /// write the entirety of `payload` to `location`, or fail. No clients /// should be able to observe a partially written object - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - self.put_opts(location, bytes, PutOptions::default()).await + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + self.put_opts(location, payload, PutOptions::default()) + .await } - /// Save the provided bytes to the specified location with the given options - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result; + /// Save the provided `payload` to `location` with the given options + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result; /// Perform a multipart upload /// @@ -616,11 +622,10 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// # use object_store::{ObjectStore, ObjectMeta}; /// # use object_store::path::Path; /// # use futures::{StreamExt, TryStreamExt}; - /// # use bytes::Bytes; /// # /// // Create two objects - /// store.put(&Path::from("foo"), Bytes::from("foo")).await?; - /// store.put(&Path::from("bar"), Bytes::from("bar")).await?; + /// store.put(&Path::from("foo"), "foo".into()).await?; + /// store.put(&Path::from("bar"), "bar".into()).await?; /// /// // List object /// let locations = store.list(None).map_ok(|m| m.location).boxed(); @@ -717,17 +722,17 @@ macro_rules! as_ref_impl { ($type:ty) => { #[async_trait] impl ObjectStore for $type { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - self.as_ref().put(location, bytes).await + async fn put(&self, location: &Path, payload: PutPayload) -> Result { + self.as_ref().put(location, payload).await } async fn put_opts( &self, location: &Path, - bytes: Bytes, + payload: PutPayload, opts: PutOptions, ) -> Result { - self.as_ref().put_opts(location, bytes, opts).await + self.as_ref().put_opts(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -1219,8 +1224,7 @@ mod tests { let location = Path::from("test_dir/test_file.json"); let data = Bytes::from("arbitrary data"); - let expected_data = data.clone(); - storage.put(&location, data).await.unwrap(); + storage.put(&location, data.clone().into()).await.unwrap(); let root = Path::from("/"); @@ -1263,14 +1267,14 @@ mod tests { assert!(content_list.is_empty()); let read_data = storage.get(&location).await.unwrap().bytes().await.unwrap(); - assert_eq!(&*read_data, expected_data); + assert_eq!(&*read_data, data); // Test range request let range = 3..7; let range_result = storage.get_range(&location, range.clone()).await; let bytes = range_result.unwrap(); - assert_eq!(bytes, expected_data.slice(range.clone())); + assert_eq!(bytes, data.slice(range.clone())); let opts = GetOptions { range: Some(GetRange::Bounded(2..5)), @@ -1348,11 +1352,11 @@ mod tests { let ranges = vec![0..1, 2..3, 0..5]; let bytes = storage.get_ranges(&location, &ranges).await.unwrap(); for (range, bytes) in ranges.iter().zip(bytes) { - assert_eq!(bytes, expected_data.slice(range.clone())) + assert_eq!(bytes, data.slice(range.clone())) } let head = storage.head(&location).await.unwrap(); - assert_eq!(head.size, expected_data.len()); + assert_eq!(head.size, data.len()); storage.delete(&location).await.unwrap(); @@ -1369,7 +1373,7 @@ mod tests { let file_with_delimiter = Path::from_iter(["a", "b/c", "foo.file"]); storage - .put(&file_with_delimiter, Bytes::from("arbitrary")) + .put(&file_with_delimiter, "arbitrary".into()) .await .unwrap(); @@ -1409,10 +1413,7 @@ mod tests { let emoji_prefix = Path::from("🙀"); let emoji_file = Path::from("🙀/😀.parquet"); - storage - .put(&emoji_file, Bytes::from("arbitrary")) - .await - .unwrap(); + storage.put(&emoji_file, "arbitrary".into()).await.unwrap(); storage.head(&emoji_file).await.unwrap(); storage @@ -1464,7 +1465,7 @@ mod tests { let hello_prefix = Path::parse("%48%45%4C%4C%4F").unwrap(); let path = hello_prefix.child("foo.parquet"); - storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + storage.put(&path, vec![0, 1].into()).await.unwrap(); let files = flatten_list_stream(storage, Some(&hello_prefix)) .await .unwrap(); @@ -1504,7 +1505,7 @@ mod tests { // Can also write non-percent encoded sequences let path = Path::parse("%Q.parquet").unwrap(); - storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + storage.put(&path, vec![0, 1].into()).await.unwrap(); let files = flatten_list_stream(storage, None).await.unwrap(); assert_eq!(files, vec![path.clone()]); @@ -1512,7 +1513,7 @@ mod tests { storage.delete(&path).await.unwrap(); let path = Path::parse("foo bar/I contain spaces.parquet").unwrap(); - storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + storage.put(&path, vec![0, 1].into()).await.unwrap(); storage.head(&path).await.unwrap(); let files = flatten_list_stream(storage, Some(&Path::from("foo bar"))) @@ -1622,7 +1623,7 @@ mod tests { delete_fixtures(storage).await; let path = Path::from("empty"); - storage.put(&path, Bytes::new()).await.unwrap(); + storage.put(&path, PutPayload::default()).await.unwrap(); let meta = storage.head(&path).await.unwrap(); assert_eq!(meta.size, 0); let data = storage.get(&path).await.unwrap().bytes().await.unwrap(); @@ -1879,7 +1880,7 @@ mod tests { let data = get_chunks(5 * 1024 * 1024, 3); let bytes_expected = data.concat(); let mut upload = storage.put_multipart(&location).await.unwrap(); - let uploads = data.into_iter().map(|x| upload.put_part(x)); + let uploads = data.into_iter().map(|x| upload.put_part(x.into())); futures::future::try_join_all(uploads).await.unwrap(); // Object should not yet exist in store @@ -1928,7 +1929,7 @@ mod tests { // We can abort an in-progress write let mut upload = storage.put_multipart(&location).await.unwrap(); upload - .put_part(data.first().unwrap().clone()) + .put_part(data.first().unwrap().clone().into()) .await .unwrap(); @@ -1953,7 +1954,7 @@ mod tests { let location1 = Path::from("foo/x.json"); let location2 = Path::from("foo.bar/y.json"); - let data = Bytes::from("arbitrary data"); + let data = PutPayload::from("arbitrary data"); storage.put(&location1, data.clone()).await.unwrap(); storage.put(&location2, data).await.unwrap(); @@ -2011,8 +2012,7 @@ mod tests { .collect(); for f in &files { - let data = data.clone(); - storage.put(f, data).await.unwrap(); + storage.put(f, data.clone().into()).await.unwrap(); } // ==================== check: prefix-list `mydb/wb` (directory) ==================== @@ -2076,15 +2076,15 @@ mod tests { let contents2 = Bytes::from("dogs"); // copy() make both objects identical - storage.put(&path1, contents1.clone()).await.unwrap(); - storage.put(&path2, contents2.clone()).await.unwrap(); + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); storage.copy(&path1, &path2).await.unwrap(); let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); assert_eq!(&new_contents, &contents1); // rename() copies contents and deletes original - storage.put(&path1, contents1.clone()).await.unwrap(); - storage.put(&path2, contents2.clone()).await.unwrap(); + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); storage.rename(&path1, &path2).await.unwrap(); let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); assert_eq!(&new_contents, &contents1); @@ -2104,8 +2104,8 @@ mod tests { let contents2 = Bytes::from("dogs"); // copy_if_not_exists() errors if destination already exists - storage.put(&path1, contents1.clone()).await.unwrap(); - storage.put(&path2, contents2.clone()).await.unwrap(); + storage.put(&path1, contents1.clone().into()).await.unwrap(); + storage.put(&path2, contents2.clone().into()).await.unwrap(); let result = storage.copy_if_not_exists(&path1, &path2).await; assert!(result.is_err()); assert!(matches!( @@ -2133,7 +2133,7 @@ mod tests { // Create destination object let path2 = Path::from("test2"); - storage.put(&path2, Bytes::from("hello")).await.unwrap(); + storage.put(&path2, "hello".into()).await.unwrap(); // copy() errors if source does not exist let result = storage.copy(&path1, &path2).await; @@ -2164,7 +2164,7 @@ mod tests { let parts: Vec<_> = futures::stream::iter(chunks) .enumerate() - .map(|(idx, b)| multipart.put_part(&path, &id, idx, b)) + .map(|(idx, b)| multipart.put_part(&path, &id, idx, b.into())) .buffered(2) .try_collect() .await @@ -2204,7 +2204,7 @@ mod tests { let data = Bytes::from("hello world"); let path = Path::from("file.txt"); - integration.put(&path, data.clone()).await.unwrap(); + integration.put(&path, data.clone().into()).await.unwrap(); let signed = integration .signed_url(Method::GET, &path, Duration::from_secs(60)) diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index e5f6841638e1..b94aa05b8b6e 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -19,7 +19,7 @@ use crate::{ BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, - ObjectStore, Path, PutOptions, PutResult, Result, StreamExt, UploadPart, + ObjectStore, Path, PutOptions, PutPayload, PutResult, Result, StreamExt, UploadPart, }; use async_trait::async_trait; use bytes::Bytes; @@ -70,14 +70,19 @@ impl std::fmt::Display for LimitStore { #[async_trait] impl ObjectStore for LimitStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { let _permit = self.semaphore.acquire().await.unwrap(); - self.inner.put(location, bytes).await + self.inner.put(location, payload).await } - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { let _permit = self.semaphore.acquire().await.unwrap(); - self.inner.put_opts(location, bytes, opts).await + self.inner.put_opts(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { let upload = self.inner.put_multipart(location).await?; @@ -232,7 +237,7 @@ impl LimitUpload { #[async_trait] impl MultipartUpload for LimitUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, data: PutPayload) -> UploadPart { let upload = self.upload.put_part(data); let s = Arc::clone(&self.semaphore); Box::pin(async move { diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 6cc0c672af45..0d7c279b3190 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -39,7 +39,7 @@ use crate::{ path::{absolute_path_to_url, Path}, util::InvalidGetRange, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, - PutMode, PutOptions, PutResult, Result, UploadPart, + PutMode, PutOptions, PutPayload, PutResult, Result, UploadPart, }; /// A specialized `Error` for filesystem object store-related errors @@ -336,7 +336,12 @@ fn is_valid_file_path(path: &Path) -> bool { #[async_trait] impl ObjectStore for LocalFileSystem { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { if matches!(opts.mode, PutMode::Update(_)) { return Err(crate::Error::NotImplemented); } @@ -346,7 +351,7 @@ impl ObjectStore for LocalFileSystem { let (mut file, staging_path) = new_staged_upload(&path)?; let mut e_tag = None; - let err = match file.write_all(&bytes) { + let err = match payload.iter().try_for_each(|x| file.write_all(x)) { Ok(_) => { let metadata = file.metadata().map_err(|e| Error::Metadata { source: e.into(), @@ -724,9 +729,9 @@ impl LocalUpload { #[async_trait] impl MultipartUpload for LocalUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, data: PutPayload) -> UploadPart { let offset = self.offset; - self.offset += data.len() as u64; + self.offset += data.content_length() as u64; let s = Arc::clone(&self.state); maybe_spawn_blocking(move || { @@ -734,7 +739,11 @@ impl MultipartUpload for LocalUpload { let file = f.as_mut().context(AbortedSnafu)?; file.seek(SeekFrom::Start(offset)) .context(SeekSnafu { path: &s.dest })?; - file.write_all(&data).context(UnableToCopyDataToFileSnafu)?; + + data.iter() + .try_for_each(|x| file.write_all(x)) + .context(UnableToCopyDataToFileSnafu)?; + Ok(()) }) .boxed() @@ -1016,8 +1025,8 @@ mod tests { // Can't use stream_get test as WriteMultipart uses a tokio JoinSet let p = Path::from("manual_upload"); let mut upload = integration.put_multipart(&p).await.unwrap(); - upload.put_part(Bytes::from_static(b"123")).await.unwrap(); - upload.put_part(Bytes::from_static(b"45678")).await.unwrap(); + upload.put_part("123".into()).await.unwrap(); + upload.put_part("45678".into()).await.unwrap(); let r = upload.complete().await.unwrap(); let get = integration.get(&p).await.unwrap(); @@ -1035,9 +1044,11 @@ mod tests { let location = Path::from("nested/file/test_file"); let data = Bytes::from("arbitrary data"); - let expected_data = data.clone(); - integration.put(&location, data).await.unwrap(); + integration + .put(&location, data.clone().into()) + .await + .unwrap(); let read_data = integration .get(&location) @@ -1046,7 +1057,7 @@ mod tests { .bytes() .await .unwrap(); - assert_eq!(&*read_data, expected_data); + assert_eq!(&*read_data, data); } #[tokio::test] @@ -1057,9 +1068,11 @@ mod tests { let location = Path::from("some_file"); let data = Bytes::from("arbitrary data"); - let expected_data = data.clone(); - integration.put(&location, data).await.unwrap(); + integration + .put(&location, data.clone().into()) + .await + .unwrap(); let read_data = integration .get(&location) @@ -1068,7 +1081,7 @@ mod tests { .bytes() .await .unwrap(); - assert_eq!(&*read_data, expected_data); + assert_eq!(&*read_data, data); } #[tokio::test] @@ -1260,7 +1273,7 @@ mod tests { // Adding a file through a symlink creates in both paths integration - .put(&Path::from("b/file.parquet"), Bytes::from(vec![0, 1, 2])) + .put(&Path::from("b/file.parquet"), vec![0, 1, 2].into()) .await .unwrap(); @@ -1279,7 +1292,7 @@ mod tests { let directory = Path::from("directory"); let object = directory.child("child.txt"); let data = Bytes::from("arbitrary"); - integration.put(&object, data.clone()).await.unwrap(); + integration.put(&object, data.clone().into()).await.unwrap(); integration.head(&object).await.unwrap(); let result = integration.get(&object).await.unwrap(); assert_eq!(result.bytes().await.unwrap(), data); @@ -1319,7 +1332,7 @@ mod tests { let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); let location = Path::from("some_file"); - let data = Bytes::from("arbitrary data"); + let data = PutPayload::from("arbitrary data"); let mut u1 = integration.put_multipart(&location).await.unwrap(); u1.put_part(data.clone()).await.unwrap(); @@ -1418,12 +1431,10 @@ mod tests { #[cfg(test)] mod not_wasm_tests { use std::time::Duration; - - use bytes::Bytes; use tempfile::TempDir; use crate::local::LocalFileSystem; - use crate::{ObjectStore, Path}; + use crate::{ObjectStore, Path, PutPayload}; #[tokio::test] async fn test_cleanup_intermediate_files() { @@ -1431,7 +1442,7 @@ mod not_wasm_tests { let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); let location = Path::from("some_file"); - let data = Bytes::from_static(b"hello"); + let data = PutPayload::from_static(b"hello"); let mut upload = integration.put_multipart(&location).await.unwrap(); upload.put_part(data).await.unwrap(); diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 6c960d4f24fb..d42e6f231c04 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -29,11 +29,11 @@ use snafu::{OptionExt, ResultExt, Snafu}; use crate::multipart::{MultipartStore, PartId}; use crate::util::InvalidGetRange; -use crate::GetOptions; use crate::{ path::Path, GetRange, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta, ObjectStore, PutMode, PutOptions, PutResult, Result, UpdateVersion, UploadPart, }; +use crate::{GetOptions, PutPayload}; /// A specialized `Error` for in-memory object store-related errors #[derive(Debug, Snafu)] @@ -192,10 +192,15 @@ impl std::fmt::Display for InMemory { #[async_trait] impl ObjectStore for InMemory { - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { let mut storage = self.storage.write(); let etag = storage.next_etag; - let entry = Entry::new(bytes, Utc::now(), etag); + let entry = Entry::new(payload.into(), Utc::now(), etag); match opts.mode { PutMode::Overwrite => storage.overwrite(location, entry), @@ -391,14 +396,14 @@ impl MultipartStore for InMemory { _path: &Path, id: &MultipartId, part_idx: usize, - data: Bytes, + payload: PutPayload, ) -> Result { let mut storage = self.storage.write(); let upload = storage.upload_mut(id)?; if part_idx <= upload.parts.len() { upload.parts.resize(part_idx + 1, None); } - upload.parts[part_idx] = Some(data); + upload.parts[part_idx] = Some(payload.into()); Ok(PartId { content_id: Default::default(), }) @@ -471,21 +476,22 @@ impl InMemory { #[derive(Debug)] struct InMemoryUpload { location: Path, - parts: Vec, + parts: Vec, storage: Arc>, } #[async_trait] impl MultipartUpload for InMemoryUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { - self.parts.push(data); + fn put_part(&mut self, payload: PutPayload) -> UploadPart { + self.parts.push(payload); Box::pin(futures::future::ready(Ok(()))) } async fn complete(&mut self) -> Result { - let cap = self.parts.iter().map(|x| x.len()).sum(); + let cap = self.parts.iter().map(|x| x.content_length()).sum(); let mut buf = Vec::with_capacity(cap); - self.parts.iter().for_each(|x| buf.extend_from_slice(x)); + let parts = self.parts.iter().flatten(); + parts.for_each(|x| buf.extend_from_slice(x)); let etag = self.storage.write().insert(&self.location, buf.into()); Ok(PutResult { e_tag: Some(etag.to_string()), @@ -552,9 +558,11 @@ mod tests { let location = Path::from("some_file"); let data = Bytes::from("arbitrary data"); - let expected_data = data.clone(); - integration.put(&location, data).await.unwrap(); + integration + .put(&location, data.clone().into()) + .await + .unwrap(); let read_data = integration .get(&location) @@ -563,7 +571,7 @@ mod tests { .bytes() .await .unwrap(); - assert_eq!(&*read_data, expected_data); + assert_eq!(&*read_data, data); } const NON_EXISTENT_NAME: &str = "nonexistentname"; diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs index 26cce3936244..d94e7f150513 100644 --- a/object_store/src/multipart.rs +++ b/object_store/src/multipart.rs @@ -22,10 +22,9 @@ //! especially useful when dealing with large files or high-throughput systems. use async_trait::async_trait; -use bytes::Bytes; use crate::path::Path; -use crate::{MultipartId, PutResult, Result}; +use crate::{MultipartId, PutPayload, PutResult, Result}; /// Represents a part of a file that has been successfully uploaded in a multipart upload process. #[derive(Debug, Clone)] @@ -64,7 +63,7 @@ pub trait MultipartStore: Send + Sync + 'static { path: &Path, id: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result; /// Completes a multipart upload diff --git a/object_store/src/payload.rs b/object_store/src/payload.rs new file mode 100644 index 000000000000..eba4f1bd6531 --- /dev/null +++ b/object_store/src/payload.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use bytes::Bytes; +use std::sync::Arc; + +/// A cheaply cloneable, ordered collection of [`Bytes`] +#[derive(Debug, Clone)] +pub struct PutPayload(Arc<[Bytes]>); + +impl Default for PutPayload { + fn default() -> Self { + Self(Arc::new([])) + } +} + +impl PutPayload { + /// Create a new empty [`PutPayload`] + pub fn new() -> Self { + Self::default() + } + + /// Creates a [`PutPayload`] from a static slice + pub fn from_static(s: &'static [u8]) -> Self { + s.into() + } + + /// Creates a [`PutPayload`] from a [`Bytes`] + pub fn from_bytes(s: Bytes) -> Self { + s.into() + } + + #[cfg(feature = "cloud")] + pub(crate) fn body(&self) -> reqwest::Body { + reqwest::Body::wrap_stream(futures::stream::iter( + self.clone().into_iter().map(Ok::<_, crate::Error>), + )) + } + + /// Returns the total length of the [`Bytes`] in this payload + pub fn content_length(&self) -> usize { + self.0.iter().map(|b| b.len()).sum() + } + + /// Returns an iterator over the [`Bytes`] in this payload + pub fn iter(&self) -> PutPayloadIter<'_> { + PutPayloadIter(self.0.iter()) + } +} + +impl<'a> IntoIterator for &'a PutPayload { + type Item = &'a Bytes; + type IntoIter = PutPayloadIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl IntoIterator for PutPayload { + type Item = Bytes; + type IntoIter = PutPayloadIntoIter; + + fn into_iter(self) -> Self::IntoIter { + PutPayloadIntoIter { + payload: self, + idx: 0, + } + } +} + +/// An iterator over [`PutPayload`] +#[derive(Debug)] +pub struct PutPayloadIter<'a>(std::slice::Iter<'a, Bytes>); + +impl<'a> Iterator for PutPayloadIter<'a> { + type Item = &'a Bytes; + + fn next(&mut self) -> Option { + self.0.next() + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +/// An owning iterator of [`PutPayload`] +#[derive(Debug)] +pub struct PutPayloadIntoIter { + payload: PutPayload, + idx: usize, +} + +impl Iterator for PutPayloadIntoIter { + type Item = Bytes; + + fn next(&mut self) -> Option { + let p = self.payload.0.get(self.idx)?.clone(); + self.idx += 1; + Some(p) + } + + fn size_hint(&self) -> (usize, Option) { + let l = self.payload.0.len() - self.idx; + (l, Some(l)) + } +} + +impl From for PutPayload { + fn from(value: Bytes) -> Self { + Self(Arc::new([value])) + } +} + +impl From> for PutPayload { + fn from(value: Vec) -> Self { + Self(Arc::new([value.into()])) + } +} + +impl From<&'static str> for PutPayload { + fn from(value: &'static str) -> Self { + Bytes::from(value).into() + } +} + +impl From<&'static [u8]> for PutPayload { + fn from(value: &'static [u8]) -> Self { + Bytes::from(value).into() + } +} + +impl From for PutPayload { + fn from(value: String) -> Self { + Bytes::from(value).into() + } +} + +impl FromIterator for PutPayload { + fn from_iter>(iter: T) -> Self { + // TODO Use PutPayloadMut to avoid bump allocating + Bytes::from_iter(iter).into() + } +} + +impl FromIterator for PutPayload { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl From for Bytes { + fn from(value: PutPayload) -> Self { + match value.0.len() { + 0 => Self::new(), + 1 => value.0[0].clone(), + _ => { + let mut buf = Vec::with_capacity(value.content_length()); + value.iter().for_each(|x| buf.extend_from_slice(x)); + buf.into() + } + } + } +} diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index 053f71a2d063..1d1ffeed8c63 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -23,7 +23,7 @@ use std::ops::Range; use crate::path::Path; use crate::{ GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, PutOptions, - PutResult, Result, + PutPayload, PutResult, Result, }; #[doc(hidden)] @@ -80,14 +80,19 @@ impl PrefixStore { #[async_trait::async_trait] impl ObjectStore for PrefixStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { let full_path = self.full_path(location); - self.inner.put(&full_path, bytes).await + self.inner.put(&full_path, payload).await } - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { let full_path = self.full_path(location); - self.inner.put_opts(&full_path, bytes, opts).await + self.inner.put_opts(&full_path, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -218,9 +223,8 @@ mod tests { let location = Path::from("prefix/test_file.json"); let data = Bytes::from("arbitrary data"); - let expected_data = data.clone(); - local.put(&location, data).await.unwrap(); + local.put(&location, data.clone().into()).await.unwrap(); let prefix = PrefixStore::new(local, "prefix"); let location_prefix = Path::from("test_file.json"); @@ -239,11 +243,11 @@ mod tests { .bytes() .await .unwrap(); - assert_eq!(&*read_data, expected_data); + assert_eq!(&*read_data, data); let target_prefix = Path::from("/test_written.json"); prefix - .put(&target_prefix, expected_data.clone()) + .put(&target_prefix, data.clone().into()) .await .unwrap(); @@ -256,6 +260,6 @@ mod tests { let location = Path::from("prefix/test_written.json"); let read_data = local.get(&location).await.unwrap().bytes().await.unwrap(); - assert_eq!(&*read_data, expected_data) + assert_eq!(&*read_data, data) } } diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 65fac5922f69..d089784668e9 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -23,7 +23,7 @@ use std::{convert::TryInto, sync::Arc}; use crate::multipart::{MultipartStore, PartId}; use crate::{ path::Path, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta, - ObjectStore, PutOptions, PutResult, Result, + ObjectStore, PutOptions, PutPayload, PutResult, Result, }; use crate::{GetOptions, UploadPart}; use async_trait::async_trait; @@ -148,14 +148,19 @@ impl std::fmt::Display for ThrottledStore { #[async_trait] impl ObjectStore for ThrottledStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { + async fn put(&self, location: &Path, payload: PutPayload) -> Result { sleep(self.config().wait_put_per_call).await; - self.inner.put(location, bytes).await + self.inner.put(location, payload).await } - async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { sleep(self.config().wait_put_per_call).await; - self.inner.put_opts(location, bytes, opts).await + self.inner.put_opts(location, payload, opts).await } async fn put_multipart(&self, location: &Path) -> Result> { @@ -332,7 +337,7 @@ impl MultipartStore for ThrottledStore { path: &Path, id: &MultipartId, part_idx: usize, - data: Bytes, + data: PutPayload, ) -> Result { sleep(self.config().wait_put_per_call).await; self.inner.put_part(path, id, part_idx, data).await @@ -360,7 +365,7 @@ struct ThrottledUpload { #[async_trait] impl MultipartUpload for ThrottledUpload { - fn put_part(&mut self, data: Bytes) -> UploadPart { + fn put_part(&mut self, data: PutPayload) -> UploadPart { let duration = self.sleep; let put = self.upload.put_part(data); Box::pin(async move { @@ -382,7 +387,6 @@ impl MultipartUpload for ThrottledUpload { mod tests { use super::*; use crate::{memory::InMemory, tests::*, GetResultPayload}; - use bytes::Bytes; use futures::TryStreamExt; use tokio::time::Duration; use tokio::time::Instant; @@ -536,8 +540,7 @@ mod tests { if let Some(n_bytes) = n_bytes { let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect(); - let bytes = Bytes::from(data); - store.put(&path, bytes).await.unwrap(); + store.put(&path, data.into()).await.unwrap(); } else { // ensure object is absent store.delete(&path).await.unwrap(); @@ -560,9 +563,7 @@ mod tests { // create new entries for i in 0..n_entries { let path = prefix.child(i.to_string().as_str()); - - let data = Bytes::from("bar"); - store.put(&path, data).await.unwrap(); + store.put(&path, "bar".into()).await.unwrap(); } prefix @@ -630,10 +631,9 @@ mod tests { async fn measure_put(store: &ThrottledStore, n_bytes: usize) -> Duration { let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect(); - let bytes = Bytes::from(data); let t0 = Instant::now(); - store.put(&Path::from("foo"), bytes).await.unwrap(); + store.put(&Path::from("foo"), data.into()).await.unwrap(); t0.elapsed() } diff --git a/object_store/src/upload.rs b/object_store/src/upload.rs index fe864e2821c9..d5c34562cf52 100644 --- a/object_store/src/upload.rs +++ b/object_store/src/upload.rs @@ -17,14 +17,12 @@ use std::task::{Context, Poll}; +use crate::{PutPayload, PutResult, Result}; use async_trait::async_trait; -use bytes::Bytes; use futures::future::BoxFuture; use futures::ready; use tokio::task::JoinSet; -use crate::{PutResult, Result}; - /// An upload part request pub type UploadPart = BoxFuture<'static, Result<()>>; @@ -65,7 +63,7 @@ pub trait MultipartUpload: Send + std::fmt::Debug { /// ``` /// /// [R2]: https://developers.cloudflare.com/r2/objects/multipart-objects/#limitations - fn put_part(&mut self, data: Bytes) -> UploadPart; + fn put_part(&mut self, data: PutPayload) -> UploadPart; /// Complete the multipart upload /// @@ -169,7 +167,7 @@ impl WriteMultipart { } } - fn put_part(&mut self, part: Bytes) { + fn put_part(&mut self, part: PutPayload) { self.tasks.spawn(self.upload.put_part(part)); } diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index 309a86d8fe9d..59c593400450 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -37,8 +37,13 @@ impl std::fmt::Display for MyStore { #[async_trait] impl ObjectStore for MyStore { - async fn put_opts(&self, path: &Path, data: Bytes, opts: PutOptions) -> Result { - self.0.put_opts(path, data, opts).await + async fn put_opts( + &self, + location: &Path, + payload: PutPayload, + opts: PutOptions, + ) -> Result { + self.0.put_opts(location, payload, opts).await } async fn put_multipart(&self, _location: &Path) -> Result> { @@ -77,7 +82,7 @@ async fn test_get_range() { let path = Path::from("foo"); let expected = Bytes::from_static(b"hello world"); - store.put(&path, expected.clone()).await.unwrap(); + store.put(&path, expected.clone().into()).await.unwrap(); let fetched = store.get(&path).await.unwrap().bytes().await.unwrap(); assert_eq!(expected, fetched); @@ -101,7 +106,7 @@ async fn test_get_opts_over_range() { let path = Path::from("foo"); let expected = Bytes::from_static(b"hello world"); - store.put(&path, expected.clone()).await.unwrap(); + store.put(&path, expected.clone().into()).await.unwrap(); let opts = GetOptions { range: Some(GetRange::Bounded(0..(expected.len() * 2))),