diff --git a/object_store/src/aws/builder.rs b/object_store/src/aws/builder.rs index 75a5299a0859..79ea75b5aba2 100644 --- a/object_store/src/aws/builder.rs +++ b/object_store/src/aws/builder.rs @@ -20,7 +20,8 @@ use crate::aws::credential::{ InstanceCredentialProvider, TaskCredentialProvider, WebIdentityProvider, }; use crate::aws::{ - AmazonS3, AwsCredential, AwsCredentialProvider, Checksum, S3CopyIfNotExists, STORE, + AmazonS3, AwsCredential, AwsCredentialProvider, Checksum, S3ConditionalPut, S3CopyIfNotExists, + STORE, }; use crate::client::TokenCredentialProvider; use crate::config::ConfigValue; @@ -152,6 +153,8 @@ pub struct AmazonS3Builder { skip_signature: ConfigValue, /// Copy if not exists copy_if_not_exists: Option>, + /// Put precondition + conditional_put: Option>, } /// Configuration keys for [`AmazonS3Builder`] @@ -288,6 +291,11 @@ pub enum AmazonS3ConfigKey { /// See [`S3CopyIfNotExists`] CopyIfNotExists, + /// Configure how to provide conditional put operations + /// + /// See [`S3ConditionalPut`] + ConditionalPut, + /// Skip signing request SkipSignature, @@ -312,7 +320,8 @@ impl AsRef for AmazonS3ConfigKey { Self::Checksum => "aws_checksum_algorithm", Self::ContainerCredentialsRelativeUri => "aws_container_credentials_relative_uri", Self::SkipSignature => "aws_skip_signature", - Self::CopyIfNotExists => "copy_if_not_exists", + Self::CopyIfNotExists => "aws_copy_if_not_exists", + Self::ConditionalPut => "aws_conditional_put", Self::Client(opt) => opt.as_ref(), } } @@ -339,7 +348,8 @@ impl FromStr for AmazonS3ConfigKey { "aws_checksum_algorithm" | "checksum_algorithm" => Ok(Self::Checksum), "aws_container_credentials_relative_uri" => Ok(Self::ContainerCredentialsRelativeUri), "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), - "copy_if_not_exists" => Ok(Self::CopyIfNotExists), + "aws_copy_if_not_exists" | "copy_if_not_exists" => Ok(Self::CopyIfNotExists), + "aws_conditional_put" | "conditional_put" => Ok(Self::ConditionalPut), // Backwards compatibility "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), _ => match s.parse() { @@ -446,6 +456,9 @@ impl AmazonS3Builder { AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) } + AmazonS3ConfigKey::ConditionalPut => { + self.conditional_put = Some(ConfigValue::Deferred(value.into())) + } }; self } @@ -509,6 +522,9 @@ impl AmazonS3Builder { AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists.as_ref().map(ToString::to_string) } + AmazonS3ConfigKey::ConditionalPut => { + self.conditional_put.as_ref().map(ToString::to_string) + } } } @@ -713,6 +729,12 @@ impl AmazonS3Builder { self } + /// Configure how to provide conditional put operations + pub fn with_conditional_put(mut self, config: S3ConditionalPut) -> Self { + self.conditional_put = Some(config.into()); + self + } + /// Create a [`AmazonS3`] instance from the provided values, /// consuming `self`. 
pub fn build(mut self) -> Result { @@ -724,6 +746,7 @@ impl AmazonS3Builder { let region = self.region.context(MissingRegionSnafu)?; let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?; let copy_if_not_exists = self.copy_if_not_exists.map(|x| x.get()).transpose()?; + let put_precondition = self.conditional_put.map(|x| x.get()).transpose()?; let credentials = if let Some(credentials) = self.credentials { credentials @@ -830,6 +853,7 @@ impl AmazonS3Builder { skip_signature: self.skip_signature.get()?, checksum, copy_if_not_exists, + conditional_put: put_precondition, }; let client = Arc::new(S3Client::new(config)?); diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 4e98f259f8dd..20c2a96b57cd 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -17,13 +17,18 @@ use crate::aws::checksum::Checksum; use crate::aws::credential::{AwsCredential, CredentialExt}; -use crate::aws::{AwsCredentialProvider, S3CopyIfNotExists, STORE, STRICT_PATH_ENCODE_SET}; +use crate::aws::{ + AwsCredentialProvider, S3ConditionalPut, S3CopyIfNotExists, STORE, STRICT_PATH_ENCODE_SET, +}; use crate::client::get::GetClient; -use crate::client::header::get_etag; use crate::client::header::HeaderConfig; +use crate::client::header::{get_put_result, get_version}; use crate::client::list::ListClient; -use crate::client::list_response::ListResponse; use crate::client::retry::RetryExt; +use crate::client::s3::{ + CompleteMultipartUpload, CompleteMultipartUploadResult, InitiateMultipartUploadResult, + ListResponse, +}; use crate::client::GetOptionsExt; use crate::multipart::PartId; use crate::path::DELIMITER; @@ -34,17 +39,20 @@ use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::{Buf, Bytes}; +use hyper::http::HeaderName; use itertools::Itertools; use percent_encoding::{utf8_percent_encode, PercentEncode}; use quick_xml::events::{self as xml_events}; use reqwest::{ header::{CONTENT_LENGTH, CONTENT_TYPE}, - Client as ReqwestClient, Method, Response, StatusCode, + Client as ReqwestClient, Method, RequestBuilder, Response, StatusCode, }; use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; use std::sync::Arc; +const VERSION_HEADER: &str = "x-amz-version-id"; + /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] @@ -147,33 +155,6 @@ impl From for crate::Error { } } -#[derive(Debug, Deserialize)] -#[serde(rename_all = "PascalCase")] -struct InitiateMultipart { - upload_id: String, -} - -#[derive(Debug, Serialize)] -#[serde(rename_all = "PascalCase", rename = "CompleteMultipartUpload")] -struct CompleteMultipart { - part: Vec, -} - -#[derive(Debug, Serialize)] -struct MultipartPart { - #[serde(rename = "ETag")] - e_tag: String, - #[serde(rename = "PartNumber")] - part_number: usize, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "PascalCase", rename = "CompleteMultipartUploadResult")] -struct CompleteMultipartResult { - #[serde(rename = "ETag")] - e_tag: String, -} - #[derive(Deserialize)] #[serde(rename_all = "PascalCase", rename = "DeleteResult")] struct BatchDeleteResponse { @@ -225,12 +206,61 @@ pub struct S3Config { pub skip_signature: bool, pub checksum: Option, pub copy_if_not_exists: Option, + pub conditional_put: Option, } impl S3Config { pub(crate) fn path_url(&self, path: &Path) -> String { format!("{}/{}", self.bucket_endpoint, encode_path(path)) } + + async fn get_credential(&self) -> Result>> { + Ok(match 
self.skip_signature { + false => Some(self.credentials.get_credential().await?), + true => None, + }) + } +} + +/// A builder for a put request allowing customisation of the headers and query string +pub(crate) struct PutRequest<'a> { + path: &'a Path, + config: &'a S3Config, + builder: RequestBuilder, + payload_sha256: Option>, +} + +impl<'a> PutRequest<'a> { + pub fn query(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + pub fn header(self, k: &HeaderName, v: &str) -> Self { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + pub async fn send(self) -> Result { + let credential = self.config.get_credential().await?; + + let response = self + .builder + .with_aws_sigv4( + credential.as_deref(), + &self.config.region, + "s3", + self.config.sign_payload, + self.payload_sha256.as_deref(), + ) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: self.path.as_ref(), + })?; + + Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) + } } #[derive(Debug)] @@ -250,23 +280,10 @@ impl S3Client { &self.config } - async fn get_credential(&self) -> Result>> { - Ok(match self.config.skip_signature { - false => Some(self.config.credentials.get_credential().await?), - true => None, - }) - } - /// Make an S3 PUT request /// /// Returns the ETag - pub async fn put_request( - &self, - path: &Path, - bytes: Bytes, - query: &T, - ) -> Result { - let credential = self.get_credential().await?; + pub fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> PutRequest<'a> { let url = self.config.path_url(path); let mut builder = self.client.request(Method::PUT, url); let mut payload_sha256 = None; @@ -288,22 +305,12 @@ impl S3Client { builder = builder.header(CONTENT_TYPE, value); } - let response = builder - .query(query) - .with_aws_sigv4( - credential.as_deref(), - &self.config.region, - "s3", - self.config.sign_payload, - payload_sha256.as_deref(), - ) - .send_retry(&self.config.retry_config) - .await - .context(PutRequestSnafu { - path: path.as_ref(), - })?; - - Ok(get_etag(response.headers()).context(MetadataSnafu)?) 
+ PutRequest { + path, + builder, + payload_sha256, + config: &self.config, + } } /// Make an S3 Delete request @@ -312,7 +319,7 @@ impl S3Client { path: &Path, query: &T, ) -> Result<()> { - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = self.config.path_url(path); self.client @@ -346,7 +353,7 @@ impl S3Client { return Ok(Vec::new()); } - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = format!("{}?delete", self.config.bucket_endpoint); let mut buffer = Vec::new(); @@ -444,7 +451,7 @@ impl S3Client { /// Make an S3 Copy request pub async fn copy_request(&self, from: &Path, to: &Path, overwrite: bool) -> Result<()> { - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = self.config.path_url(to); let source = format!("{}/{}", self.config.bucket, encode_path(from)); @@ -492,7 +499,7 @@ impl S3Client { } pub async fn create_multipart(&self, location: &Path) -> Result { - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = format!("{}?uploads=", self.config.path_url(location),); let response = self @@ -512,7 +519,7 @@ impl S3Client { .await .context(CreateMultipartResponseBodySnafu)?; - let response: InitiateMultipart = + let response: InitiateMultipartUploadResult = quick_xml::de::from_reader(response.reader()).context(InvalidMultipartResponseSnafu)?; Ok(response.upload_id) @@ -527,15 +534,15 @@ impl S3Client { ) -> Result { let part = (part_idx + 1).to_string(); - let content_id = self - .put_request( - path, - data, - &[("partNumber", &part), ("uploadId", upload_id)], - ) + let result = self + .put_request(path, data) + .query(&[("partNumber", &part), ("uploadId", upload_id)]) + .send() .await?; - Ok(PartId { content_id }) + Ok(PartId { + content_id: result.e_tag.unwrap(), + }) } pub async fn complete_multipart( @@ -544,19 +551,10 @@ impl S3Client { upload_id: &str, parts: Vec, ) -> Result { - let parts = parts - .into_iter() - .enumerate() - .map(|(part_idx, part)| MultipartPart { - e_tag: part.content_id, - part_number: part_idx + 1, - }) - .collect(); - - let request = CompleteMultipart { part: parts }; + let request = CompleteMultipartUpload::from(parts); let body = quick_xml::se::to_string(&request).unwrap(); - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = self.config.path_url(location); let response = self @@ -575,16 +573,19 @@ impl S3Client { .await .context(CompleteMultipartRequestSnafu)?; + let version = get_version(response.headers(), VERSION_HEADER).context(MetadataSnafu)?; + let data = response .bytes() .await .context(CompleteMultipartResponseBodySnafu)?; - let response: CompleteMultipartResult = + let response: CompleteMultipartUploadResult = quick_xml::de::from_reader(data.reader()).context(InvalidMultipartResponseSnafu)?; Ok(PutResult { e_tag: Some(response.e_tag), + version, }) } } @@ -596,12 +597,12 @@ impl GetClient for S3Client { const HEADER_CONFIG: HeaderConfig = HeaderConfig { etag_required: false, last_modified_required: false, - version_header: Some("x-amz-version-id"), + version_header: Some(VERSION_HEADER), }; /// Make an S3 GET request async fn get_request(&self, path: &Path, options: GetOptions) -> Result { - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = self.config.path_url(path); 
let method = match options.head { true => Method::HEAD, @@ -643,7 +644,7 @@ impl ListClient for S3Client { token: Option<&str>, offset: Option<&str>, ) -> Result<(ListResult, Option)> { - let credential = self.get_credential().await?; + let credential = self.config.get_credential().await?; let url = self.config.bucket_endpoint.clone(); let mut query = Vec::with_capacity(4); diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index 57254c7cf4e8..99e637695059 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -35,6 +35,7 @@ use async_trait::async_trait; use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; +use reqwest::header::{IF_MATCH, IF_NONE_MATCH}; use reqwest::Method; use std::{sync::Arc, time::Duration}; use tokio::io::AsyncWrite; @@ -47,20 +48,20 @@ use crate::client::CredentialProvider; use crate::multipart::{MultiPartStore, PartId, PutPart, WriteMultiPart}; use crate::signer::Signer; use crate::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, PutResult, - Result, + Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, PutMode, + PutOptions, PutResult, Result, }; mod builder; mod checksum; mod client; -mod copy; mod credential; +mod precondition; mod resolve; pub use builder::{AmazonS3Builder, AmazonS3ConfigKey}; pub use checksum::Checksum; -pub use copy::S3CopyIfNotExists; +pub use precondition::{S3ConditionalPut, S3CopyIfNotExists}; pub use resolve::resolve_bucket_region; // http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html @@ -158,9 +159,33 @@ impl Signer for AmazonS3 { #[async_trait] impl ObjectStore for AmazonS3 { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - let e_tag = self.client.put_request(location, bytes, &()).await?; - Ok(PutResult { e_tag: Some(e_tag) }) + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + let request = self.client.put_request(location, bytes); + match (opts.mode, &self.client.config().conditional_put) { + (PutMode::Overwrite, _) => request.send().await, + (PutMode::Create | PutMode::Update(_), None) => Err(Error::NotImplemented), + (PutMode::Create, Some(S3ConditionalPut::ETagMatch)) => { + match request.header(&IF_NONE_MATCH, "*").send().await { + // Technically If-None-Match should return NotModified but some stores, + // such as R2, instead return PreconditionFailed + // https://developers.cloudflare.com/r2/api/s3/extensions/#conditional-operations-in-putobject + Err(e @ Error::NotModified { .. } | e @ Error::Precondition { .. 
}) => { + Err(Error::AlreadyExists { + path: location.to_string(), + source: Box::new(e), + }) + } + r => r, + } + } + (PutMode::Update(v), Some(S3ConditionalPut::ETagMatch)) => { + let etag = v.e_tag.ok_or_else(|| Error::Generic { + store: STORE, + source: "ETag required for conditional put".to_string().into(), + })?; + request.header(&IF_MATCH, etag.as_str()).send().await + } + } } async fn put_multipart( @@ -306,6 +331,7 @@ mod tests { let config = integration.client.config(); let is_local = config.endpoint.starts_with("http://"); let test_not_exists = config.copy_if_not_exists.is_some(); + let test_conditional_put = config.conditional_put.is_some(); // Localstack doesn't support listing with spaces https://github.com/localstack/localstack/issues/6328 put_get_delete_list_opts(&integration, is_local).await; @@ -319,6 +345,9 @@ mod tests { if test_not_exists { copy_if_not_exists(&integration).await; } + if test_conditional_put { + put_opts(&integration, true).await; + } // run integration test with unsigned payload enabled let builder = AmazonS3Builder::from_env().with_unsigned_payload(true); diff --git a/object_store/src/aws/copy.rs b/object_store/src/aws/precondition.rs similarity index 68% rename from object_store/src/aws/copy.rs rename to object_store/src/aws/precondition.rs index da4e2809be1a..a50b57fe23f7 100644 --- a/object_store/src/aws/copy.rs +++ b/object_store/src/aws/precondition.rs @@ -17,8 +17,7 @@ use crate::config::Parse; -/// Configure how to provide [`ObjectStore::copy_if_not_exists`] for -/// [`AmazonS3`]. +/// Configure how to provide [`ObjectStore::copy_if_not_exists`] for [`AmazonS3`]. /// /// [`ObjectStore::copy_if_not_exists`]: crate::ObjectStore::copy_if_not_exists /// [`AmazonS3`]: super::AmazonS3 @@ -70,3 +69,45 @@ impl Parse for S3CopyIfNotExists { }) } } + +/// Configure how to provide conditional put support for [`AmazonS3`]. 
+/// +/// [`AmazonS3`]: super::AmazonS3 +#[derive(Debug, Clone)] +#[allow(missing_copy_implementations)] +#[non_exhaustive] +pub enum S3ConditionalPut { + /// Some S3-compatible stores, such as Cloudflare R2 and MinIO, support conditional + /// put using the standard [HTTP precondition] headers If-Match and If-None-Match + /// + /// Encoded as `etag` ignoring whitespace + /// + /// [HTTP precondition]: https://datatracker.ietf.org/doc/html/rfc9110#name-preconditions + ETagMatch, +} + +impl std::fmt::Display for S3ConditionalPut { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ETagMatch => write!(f, "etag"), + } + } +} + +impl S3ConditionalPut { + fn from_str(s: &str) -> Option<Self> { + match s.trim() { + "etag" => Some(Self::ETagMatch), + _ => None, + } + } +} + +impl Parse for S3ConditionalPut { + fn parse(v: &str) -> crate::Result<Self> { + Self::from_str(v).ok_or_else(|| crate::Error::Generic { + store: "Config", + source: format!("Failed to parse \"{v}\" as S3ConditionalPut").into(), + }) + } +} diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index 9f47b9a8152b..c7bd79149872 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -19,7 +19,7 @@ use super::credential::AzureCredential; use crate::azure::credential::*; use crate::azure::{AzureCredentialProvider, STORE}; use crate::client::get::GetClient; -use crate::client::header::{get_etag, HeaderConfig}; +use crate::client::header::{get_put_result, HeaderConfig}; use crate::client::list::ListClient; use crate::client::retry::RetryExt; use crate::client::GetOptionsExt; @@ -27,25 +27,29 @@ use crate::multipart::PartId; use crate::path::DELIMITER; use crate::util::deserialize_rfc1123; use crate::{ - ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutResult, Result, RetryConfig, + ClientOptions, GetOptions, ListResult, ObjectMeta, Path, PutMode, PutOptions, PutResult, + Result, RetryConfig, }; use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; +use hyper::http::HeaderName; use itertools::Itertools; use reqwest::header::CONTENT_TYPE; use reqwest::{ - header::{HeaderValue, CONTENT_LENGTH, IF_NONE_MATCH}, - Client as ReqwestClient, Method, Response, StatusCode, + header::{HeaderValue, CONTENT_LENGTH, IF_MATCH, IF_NONE_MATCH}, + Client as ReqwestClient, Method, RequestBuilder, Response, }; use serde::{Deserialize, Serialize}; -use snafu::{ResultExt, Snafu}; +use snafu::{OptionExt, ResultExt, Snafu}; use std::collections::HashMap; use std::sync::Arc; use url::Url; +const VERSION_HEADER: &str = "x-ms-version-id"; + /// A specialized `Error` for object store-related errors #[derive(Debug, Snafu)] #[allow(missing_docs)] @@ -92,6 +96,9 @@ pub(crate) enum Error { Metadata { source: crate::client::header::Error, }, + + #[snafu(display("ETag required for conditional update"))] + MissingETag, } impl From<Error> for crate::Error { @@ -134,6 +141,39 @@ impl AzureConfig { } } +/// A builder for a put request allowing customisation of the headers and query string +struct PutRequest<'a> { + path: &'a Path, + config: &'a AzureConfig, + builder: RequestBuilder, +} + +impl<'a> PutRequest<'a> { + fn header(self, k: &HeaderName, v: &str) -> Self { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + fn query<T: Serialize + ?Sized + Sync>(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + async fn send(self) -> Result<Response> { + let
credential = self.config.credentials.get_credential().await?; + let response = self + .builder + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: self.path.as_ref(), + })?; + + Ok(response) + } +} + #[derive(Debug)] pub(crate) struct AzureClient { config: AzureConfig, @@ -156,63 +196,52 @@ impl AzureClient { self.config.credentials.get_credential().await } - /// Make an Azure PUT request - pub async fn put_request( - &self, - path: &Path, - bytes: Option, - is_block_op: bool, - query: &T, - ) -> Result { - let credential = self.get_credential().await?; + fn put_request<'a>(&'a self, path: &'a Path, bytes: Bytes) -> PutRequest<'a> { let url = self.config.path_url(path); let mut builder = self.client.request(Method::PUT, url); - if !is_block_op { - builder = builder.header(&BLOB_TYPE, "BlockBlob").query(query); - } else { - builder = builder.query(query); - } - if let Some(value) = self.config().client_options.get_content_type(path) { builder = builder.header(CONTENT_TYPE, value); } - if let Some(bytes) = bytes { - builder = builder - .header(CONTENT_LENGTH, HeaderValue::from(bytes.len())) - .body(bytes) - } else { - builder = builder.header(CONTENT_LENGTH, HeaderValue::from_static("0")); + builder = builder + .header(CONTENT_LENGTH, HeaderValue::from(bytes.len())) + .body(bytes); + + PutRequest { + path, + builder, + config: &self.config, } + } - let response = builder - .with_azure_authorization(&credential, &self.config.account) - .send_retry(&self.config.retry_config) - .await - .context(PutRequestSnafu { - path: path.as_ref(), - })?; + /// Make an Azure PUT request + pub async fn put_blob(&self, path: &Path, bytes: Bytes, opts: PutOptions) -> Result { + let builder = self.put_request(path, bytes); + + let builder = match &opts.mode { + PutMode::Overwrite => builder, + PutMode::Create => builder.header(&IF_NONE_MATCH, "*"), + PutMode::Update(v) => { + let etag = v.e_tag.as_ref().context(MissingETagSnafu)?; + builder.header(&IF_MATCH, etag) + } + }; - Ok(response) + let response = builder.header(&BLOB_TYPE, "BlockBlob").send().await?; + Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) } /// PUT a block pub async fn put_block(&self, path: &Path, part_idx: usize, data: Bytes) -> Result { let content_id = format!("{part_idx:20}"); - let block_id: BlockId = content_id.clone().into(); + let block_id = BASE64_STANDARD.encode(&content_id); - self.put_request( - path, - Some(data), - true, - &[ - ("comp", "block"), - ("blockid", &BASE64_STANDARD.encode(block_id)), - ], - ) - .await?; + self.put_request(path, data) + .query(&[("comp", "block"), ("blockid", &block_id)]) + .send() + .await?; Ok(PartId { content_id }) } @@ -224,15 +253,13 @@ impl AzureClient { .map(|part| BlockId::from(part.content_id)) .collect(); - let block_list = BlockList { blocks }; - let block_xml = block_list.to_xml(); - let response = self - .put_request(path, Some(block_xml.into()), true, &[("comp", "blocklist")]) + .put_request(path, BlockList { blocks }.to_xml().into()) + .query(&[("comp", "blocklist")]) + .send() .await?; - let e_tag = get_etag(response.headers()).context(MetadataSnafu)?; - Ok(PutResult { e_tag: Some(e_tag) }) + Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) 
} /// Make an Azure Delete request @@ -284,13 +311,7 @@ impl AzureClient { .with_azure_authorization(&credential, &self.config.account) .send_retry(&self.config.retry_config) .await - .map_err(|err| match err.status() { - Some(StatusCode::CONFLICT) => crate::Error::AlreadyExists { - source: Box::new(err), - path: to.to_string(), - }, - _ => err.error(STORE, from.to_string()), - })?; + .map_err(|err| err.error(STORE, from.to_string()))?; Ok(()) } @@ -303,7 +324,7 @@ impl GetClient for AzureClient { const HEADER_CONFIG: HeaderConfig = HeaderConfig { etag_required: true, last_modified_required: true, - version_header: Some("x-ms-version-id"), + version_header: Some(VERSION_HEADER), }; /// Make an Azure GET request diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index 779ac2f71ff8..762a51dd9d60 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -29,7 +29,8 @@ use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, path::Path, - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, Result, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + Result, }; use async_trait::async_trait; use bytes::Bytes; @@ -49,7 +50,6 @@ mod credential; /// [`CredentialProvider`] for [`MicrosoftAzure`] pub type AzureCredentialProvider = Arc>; -use crate::client::header::get_etag; use crate::multipart::MultiPartStore; pub use builder::{AzureConfigKey, MicrosoftAzureBuilder}; pub use credential::AzureCredential; @@ -82,16 +82,8 @@ impl std::fmt::Display for MicrosoftAzure { #[async_trait] impl ObjectStore for MicrosoftAzure { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - let response = self - .client - .put_request(location, Some(bytes), false, &()) - .await?; - let e_tag = get_etag(response.headers()).map_err(|e| crate::Error::Generic { - store: STORE, - source: Box::new(e), - })?; - Ok(PutResult { e_tag: Some(e_tag) }) + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + self.client.put_blob(location, bytes, opts).await } async fn put_multipart( @@ -208,6 +200,7 @@ mod tests { rename_and_copy(&integration).await; copy_if_not_exists(&integration).await; stream_get(&integration).await; + put_opts(&integration, true).await; multipart(&integration, &integration).await; } diff --git a/object_store/src/chunked.rs b/object_store/src/chunked.rs index 021f9f50156b..d33556f4b12e 100644 --- a/object_store/src/chunked.rs +++ b/object_store/src/chunked.rs @@ -29,7 +29,8 @@ use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ - GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutResult, + GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, + PutResult, }; use crate::{MultipartId, Result}; @@ -62,8 +63,8 @@ impl Display for ChunkedStore { #[async_trait] impl ObjectStore for ChunkedStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - self.inner.put(location, bytes).await + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + self.inner.put_opts(location, bytes, opts).await } async fn put_multipart( diff --git a/object_store/src/client/header.rs b/object_store/src/client/header.rs index e67496833b99..e85bf6ba52d0 100644 --- a/object_store/src/client/header.rs +++ b/object_store/src/client/header.rs @@ -67,6 +67,23 @@ pub enum Error { }, } +/// Extracts a PutResult from the provided [`HeaderMap`] 
+#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +pub fn get_put_result(headers: &HeaderMap, version: &str) -> Result<crate::PutResult, Error> { + let e_tag = Some(get_etag(headers)?); + let version = get_version(headers, version)?; + Ok(crate::PutResult { e_tag, version }) +} + +/// Extracts an optional version from the provided [`HeaderMap`] +#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] +pub fn get_version(headers: &HeaderMap, version: &str) -> Result<Option<String>, Error> { + Ok(match headers.get(version) { + Some(x) => Some(x.to_str().context(BadHeaderSnafu)?.to_string()), + None => None, + }) +} + /// Extracts an etag from the provided [`HeaderMap`] pub fn get_etag(headers: &HeaderMap) -> Result<String, Error> { let e_tag = headers.get(ETAG).ok_or(Error::MissingEtag)?; diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index 77eee7fc92f3..ae092edac095 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -38,7 +38,7 @@ pub mod token; pub mod header; #[cfg(any(feature = "aws", feature = "gcp"))] -pub mod list_response; +pub mod s3; use async_trait::async_trait; use std::collections::HashMap; diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index d70d6d88de32..789103c0f74f 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -79,6 +79,10 @@ impl Error { path, source: Box::new(self), }, + Some(StatusCode::CONFLICT) => crate::Error::AlreadyExists { + path, + source: Box::new(self), + }, _ => crate::Error::Generic { store, source: Box::new(self), diff --git a/object_store/src/client/list_response.rs b/object_store/src/client/s3.rs similarity index 68% rename from object_store/src/client/list_response.rs rename to object_store/src/client/s3.rs index 7a170c584156..61237dc4beab 100644 --- a/object_store/src/client/list_response.rs +++ b/object_store/src/client/s3.rs @@ -14,12 +14,13 @@ // specific language governing permissions and limitations // under the License. -//! The list response format used by GCP and AWS +//! The list and multipart API used by both GCS and S3 +use crate::multipart::PartId; use crate::path::Path; use crate::{ListResult, ObjectMeta, Result}; use chrono::{DateTime, Utc}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize)] #[serde(rename_all = "PascalCase")] @@ -84,3 +85,44 @@ impl TryFrom for ObjectMeta { }) } } + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct InitiateMultipartUploadResult { + pub upload_id: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct CompleteMultipartUpload { + pub part: Vec<MultipartPart>, +} + +impl From<Vec<PartId>> for CompleteMultipartUpload { + fn from(value: Vec<PartId>) -> Self { + let part = value + .into_iter() + .enumerate() + .map(|(part_number, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_number + 1, + }) + .collect(); + Self { part } + } +} + +#[derive(Debug, Serialize)] +pub struct MultipartPart { + #[serde(rename = "ETag")] + pub e_tag: String, + #[serde(rename = "PartNumber")] + pub part_number: usize, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct CompleteMultipartUploadResult { + #[serde(rename = "ETag")] + pub e_tag: String, +} diff --git a/object_store/src/gcp/client.rs b/object_store/src/gcp/client.rs index 8c44f9016480..78964077e2fe 100644 --- a/object_store/src/gcp/client.rs +++ b/object_store/src/gcp/client.rs @@ -16,23 +16,34 @@ // under the License.
use crate::client::get::GetClient; -use crate::client::header::{get_etag, HeaderConfig}; +use crate::client::header::{get_put_result, get_version, HeaderConfig}; use crate::client::list::ListClient; -use crate::client::list_response::ListResponse; use crate::client::retry::RetryExt; +use crate::client::s3::{ + CompleteMultipartUpload, CompleteMultipartUploadResult, InitiateMultipartUploadResult, + ListResponse, +}; use crate::client::GetOptionsExt; use crate::gcp::{GcpCredential, GcpCredentialProvider, STORE}; use crate::multipart::PartId; use crate::path::{Path, DELIMITER}; -use crate::{ClientOptions, GetOptions, ListResult, MultipartId, PutResult, Result, RetryConfig}; +use crate::{ + ClientOptions, GetOptions, ListResult, MultipartId, PutMode, PutOptions, PutResult, Result, + RetryConfig, +}; use async_trait::async_trait; use bytes::{Buf, Bytes}; use percent_encoding::{percent_encode, utf8_percent_encode, NON_ALPHANUMERIC}; -use reqwest::{header, Client, Method, Response, StatusCode}; +use reqwest::header::HeaderName; +use reqwest::{header, Client, Method, RequestBuilder, Response, StatusCode}; use serde::Serialize; -use snafu::{ResultExt, Snafu}; +use snafu::{OptionExt, ResultExt, Snafu}; use std::sync::Arc; +const VERSION_HEADER: &str = "x-goog-generation"; + +static VERSION_MATCH: HeaderName = HeaderName::from_static("x-goog-if-generation-match"); + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Error performing list request: {}", source))] @@ -78,6 +89,18 @@ enum Error { Metadata { source: crate::client::header::Error, }, + + #[snafu(display("Version required for conditional update"))] + MissingVersion, + + #[snafu(display("Error performing complete multipart request: {}", source))] + CompleteMultipartRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting complete multipart response body: {}", source))] + CompleteMultipartResponseBody { source: reqwest::Error }, + + #[snafu(display("Got invalid multipart response: {}", source))] + InvalidMultipartResponse { source: quick_xml::de::DeError }, } impl From for crate::Error { @@ -107,6 +130,39 @@ pub struct GoogleCloudStorageConfig { pub client_options: ClientOptions, } +/// A builder for a put request allowing customisation of the headers and query string +pub struct PutRequest<'a> { + path: &'a Path, + config: &'a GoogleCloudStorageConfig, + builder: RequestBuilder, +} + +impl<'a> PutRequest<'a> { + fn header(self, k: &HeaderName, v: &str) -> Self { + let builder = self.builder.header(k, v); + Self { builder, ..self } + } + + fn query(self, query: &T) -> Self { + let builder = self.builder.query(query); + Self { builder, ..self } + } + + async fn send(self) -> Result { + let credential = self.config.credentials.get_credential().await?; + let response = self + .builder + .bearer_auth(&credential.bearer) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: self.path.as_ref(), + })?; + + Ok(get_put_result(response.headers(), VERSION_HEADER).context(MetadataSnafu)?) 
+ } +} + #[derive(Debug)] pub struct GoogleCloudStorageClient { config: GoogleCloudStorageConfig, @@ -152,13 +208,7 @@ impl GoogleCloudStorageClient { /// Perform a put request /// /// Returns the new ETag - pub async fn put_request( - &self, - path: &Path, - payload: Bytes, - query: &T, - ) -> Result { - let credential = self.get_credential().await?; + pub fn put_request<'a>(&'a self, path: &'a Path, payload: Bytes) -> PutRequest<'a> { let url = self.object_url(path); let content_type = self @@ -167,21 +217,38 @@ impl GoogleCloudStorageClient { .get_content_type(path) .unwrap_or("application/octet-stream"); - let response = self + let builder = self .client .request(Method::PUT, url) - .query(query) - .bearer_auth(&credential.bearer) .header(header::CONTENT_TYPE, content_type) .header(header::CONTENT_LENGTH, payload.len()) - .body(payload) - .send_retry(&self.config.retry_config) - .await - .context(PutRequestSnafu { - path: path.as_ref(), - })?; + .body(payload); - Ok(get_etag(response.headers()).context(MetadataSnafu)?) + PutRequest { + path, + builder, + config: &self.config, + } + } + + pub async fn put(&self, path: &Path, data: Bytes, opts: PutOptions) -> Result { + let builder = self.put_request(path, data); + + let builder = match &opts.mode { + PutMode::Overwrite => builder, + PutMode::Create => builder.header(&VERSION_MATCH, "0"), + PutMode::Update(v) => { + let etag = v.version.as_ref().context(MissingVersionSnafu)?; + builder.header(&VERSION_MATCH, etag) + } + }; + + match (opts.mode, builder.send().await) { + (PutMode::Create, Err(crate::Error::Precondition { path, source })) => { + Err(crate::Error::AlreadyExists { path, source }) + } + (_, r) => r, + } } /// Perform a put part request @@ -194,18 +261,15 @@ impl GoogleCloudStorageClient { part_idx: usize, data: Bytes, ) -> Result { - let content_id = self - .put_request( - path, - data, - &[ - ("partNumber", &format!("{}", part_idx + 1)), - ("uploadId", upload_id), - ], - ) - .await?; - - Ok(PartId { content_id }) + let query = &[ + ("partNumber", &format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ]; + let result = self.put_request(path, data).query(query).send().await?; + + Ok(PartId { + content_id: result.e_tag.unwrap(), + }) } /// Initiate a multi-part upload @@ -268,17 +332,8 @@ impl GoogleCloudStorageClient { let upload_id = multipart_id.clone(); let url = self.object_url(path); - let parts = completed_parts - .into_iter() - .enumerate() - .map(|(part_number, part)| MultipartPart { - e_tag: part.content_id, - part_number: part_number + 1, - }) - .collect(); - + let upload_info = CompleteMultipartUpload::from(completed_parts); let credential = self.get_credential().await?; - let upload_info = CompleteMultipartUpload { parts }; let data = quick_xml::se::to_string(&upload_info) .context(InvalidPutResponseSnafu)? 
@@ -287,7 +342,7 @@ impl GoogleCloudStorageClient { // https://github.com/tafia/quick-xml/issues/350 .replace(""", "\""); - let result = self + let response = self .client .request(Method::POST, &url) .bearer_auth(&credential.bearer) @@ -295,12 +350,22 @@ impl GoogleCloudStorageClient { .body(data) .send_retry(&self.config.retry_config) .await - .context(PostRequestSnafu { - path: path.as_ref(), - })?; + .context(CompleteMultipartRequestSnafu)?; - let etag = get_etag(result.headers()).context(MetadataSnafu)?; - Ok(PutResult { e_tag: Some(etag) }) + let version = get_version(response.headers(), VERSION_HEADER).context(MetadataSnafu)?; + + let data = response + .bytes() + .await + .context(CompleteMultipartResponseBodySnafu)?; + + let response: CompleteMultipartUploadResult = + quick_xml::de::from_reader(data.reader()).context(InvalidMultipartResponseSnafu)?; + + Ok(PutResult { + e_tag: Some(response.e_tag), + version, + }) } /// Perform a delete request @@ -334,7 +399,7 @@ impl GoogleCloudStorageClient { .header("x-goog-copy-source", source); if if_not_exists { - builder = builder.header("x-goog-if-generation-match", 0); + builder = builder.header(&VERSION_MATCH, 0); } builder @@ -362,7 +427,7 @@ impl GetClient for GoogleCloudStorageClient { const HEADER_CONFIG: HeaderConfig = HeaderConfig { etag_required: true, last_modified_required: true, - version_header: Some("x-goog-generation"), + version_header: Some(VERSION_HEADER), }; /// Perform a get request @@ -375,13 +440,18 @@ impl GetClient for GoogleCloudStorageClient { false => Method::GET, }; - let mut request = self.client.request(method, url).with_get_options(options); + let mut request = self.client.request(method, url); + + if let Some(version) = &options.version { + request = request.query(&[("generation", version)]); + } if !credential.bearer.is_empty() { request = request.bearer_auth(&credential.bearer); } let response = request + .with_get_options(options) .send_retry(&self.config.retry_config) .await .context(GetRequestSnafu { @@ -444,24 +514,3 @@ impl ListClient for GoogleCloudStorageClient { Ok((response.try_into()?, token)) } } - -#[derive(serde::Deserialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct InitiateMultipartUploadResult { - upload_id: String, -} - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "PascalCase", rename(serialize = "Part"))] -struct MultipartPart { - #[serde(rename = "PartNumber")] - part_number: usize, - e_tag: String, -} - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "PascalCase")] -struct CompleteMultipartUpload { - #[serde(rename = "Part", default)] - parts: Vec, -} diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 0eb3e9c23c43..7721b1278a80 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -35,7 +35,8 @@ use crate::client::CredentialProvider; use crate::{ multipart::{PartId, PutPart, WriteMultiPart}, path::Path, - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, Result, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + Result, }; use async_trait::async_trait; use bytes::Bytes; @@ -107,9 +108,8 @@ impl PutPart for GCSMultipartUpload { #[async_trait] impl ObjectStore for GoogleCloudStorage { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - let e_tag = self.client.put_request(location, bytes, &()).await?; - Ok(PutResult { e_tag: Some(e_tag) }) + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: 
PutOptions) -> Result { + self.client.put(location, bytes, opts).await } async fn put_multipart( @@ -221,6 +221,7 @@ mod test { multipart(&integration, &integration).await; // Fake GCS server doesn't currently honor preconditions get_opts(&integration).await; + put_opts(&integration, true).await; } } diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index a7dbdfcbe844..8700775fb243 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -243,6 +243,10 @@ impl Client { .header("Destination", self.path_url(to).as_str()); if !overwrite { + // While the Overwrite header appears to duplicate + // the functionality of the If-Match: * header of HTTP/1.1, If-Match + // applies only to the Request-URI, and not to the Destination of a COPY + // or MOVE. builder = builder.header("Overwrite", "F"); } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index 8f61011ccae1..cfcde27fd781 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -46,7 +46,7 @@ use crate::http::client::Client; use crate::path::Path; use crate::{ ClientConfigKey, ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, - ObjectStore, PutResult, Result, RetryConfig, + ObjectStore, PutMode, PutOptions, PutResult, Result, RetryConfig, }; mod client; @@ -96,14 +96,23 @@ impl std::fmt::Display for HttpStore { #[async_trait] impl ObjectStore for HttpStore { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + if opts.mode != PutMode::Overwrite { + // TODO: Add support for If header - https://datatracker.ietf.org/doc/html/rfc2518#section-9.4 + return Err(crate::Error::NotImplemented); + } + let response = self.client.put(location, bytes).await?; let e_tag = match get_etag(response.headers()) { Ok(e_tag) => Some(e_tag), Err(crate::client::header::Error::MissingEtag) => None, Err(source) => return Err(Error::Metadata { source }.into()), }; - Ok(PutResult { e_tag }) + + Ok(PutResult { + e_tag, + version: None, + }) } async fn put_multipart( diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 9a0667229803..66964304e853 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -299,7 +299,12 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { /// The operation is guaranteed to be atomic, it will either successfully /// write the entirety of `bytes` to `location`, or fail. No clients /// should be able to observe a partially written object - async fn put(&self, location: &Path, bytes: Bytes) -> Result; + async fn put(&self, location: &Path, bytes: Bytes) -> Result { + self.put_opts(location, bytes, PutOptions::default()).await + } + + /// Save the provided bytes to the specified location with the given options + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result; /// Get a multi-part upload that allows writing data in chunks. /// @@ -531,6 +536,15 @@ macro_rules! 
as_ref_impl { self.as_ref().put(location, bytes).await } + async fn put_opts( + &self, + location: &Path, + bytes: Bytes, + opts: PutOptions, + ) -> Result { + self.as_ref().put_opts(location, bytes, opts).await + } + async fn put_multipart( &self, location: &Path, @@ -837,13 +851,65 @@ impl GetResult { } } +/// Configure preconditions for the put operation +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum PutMode { + /// Perform an atomic write operation, overwriting any object present at the provided path + #[default] + Overwrite, + /// Perform an atomic write operation, returning [`Error::AlreadyExists`] if an + /// object already exists at the provided path + Create, + /// Perform an atomic write operation if the current version of the object matches the + /// provided [`UpdateVersion`], returning [`Error::Precondition`] otherwise + Update(UpdateVersion), +} + +/// Uniquely identifies a version of an object to update +/// +/// Stores will use differing combinations of `e_tag` and `version` to provide conditional +/// updates, and it is therefore recommended applications preserve both +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpdateVersion { + /// The unique identifier for the newly created object + /// + /// + pub e_tag: Option, + /// A version indicator for the newly created object + pub version: Option, +} + +impl From for UpdateVersion { + fn from(value: PutResult) -> Self { + Self { + e_tag: value.e_tag, + version: value.version, + } + } +} + +/// Options for a put request +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct PutOptions { + /// Configure the [`PutMode`] for this operation + pub mode: PutMode, +} + +impl From for PutOptions { + fn from(mode: PutMode) -> Self { + Self { mode } + } +} + /// Result for a put request #[derive(Debug, Clone, PartialEq, Eq)] pub struct PutResult { - /// The unique identifier for the object + /// The unique identifier for the newly created object /// /// pub e_tag: Option, + /// A version indicator for the newly created object + pub version: Option, } /// A specialized `Result` for object store-related errors @@ -947,6 +1013,7 @@ mod tests { use crate::multipart::MultiPartStore; use crate::test_util::flatten_list_stream; use chrono::TimeZone; + use futures::stream::FuturesUnordered; use rand::{thread_rng, Rng}; use tokio::io::AsyncWriteExt; @@ -1406,7 +1473,7 @@ mod tests { // Can retrieve previous version let get_opts = storage.get_opts(&path, options).await.unwrap(); let old = get_opts.bytes().await.unwrap(); - assert_eq!(old, b"foo".as_slice()); + assert_eq!(old, b"test".as_slice()); // Current version contains the updated data let current = storage.get(&path).await.unwrap().bytes().await.unwrap(); @@ -1414,6 +1481,104 @@ mod tests { } } + pub(crate) async fn put_opts(storage: &dyn ObjectStore, supports_update: bool) { + delete_fixtures(storage).await; + let path = Path::from("put_opts"); + let v1 = storage + .put_opts(&path, "a".into(), PutMode::Create.into()) + .await + .unwrap(); + + let err = storage + .put_opts(&path, "b".into(), PutMode::Create.into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::AlreadyExists { .. 
}), "{err}"); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"a"); + + if !supports_update { + return; + } + + let v2 = storage + .put_opts(&path, "c".into(), PutMode::Update(v1.clone().into()).into()) + .await + .unwrap(); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"c"); + + let err = storage + .put_opts(&path, "d".into(), PutMode::Update(v1.into()).into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + storage + .put_opts(&path, "e".into(), PutMode::Update(v2.clone().into()).into()) + .await + .unwrap(); + + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + assert_eq!(b.as_ref(), b"e"); + + // Update not exists + let path = Path::from("I don't exist"); + let err = storage + .put_opts(&path, "e".into(), PutMode::Update(v2.into()).into()) + .await + .unwrap_err(); + assert!(matches!(err, Error::Precondition { .. }), "{err}"); + + const NUM_WORKERS: usize = 5; + const NUM_INCREMENTS: usize = 10; + + let path = Path::from("RACE"); + let mut futures: FuturesUnordered<_> = (0..NUM_WORKERS) + .map(|_| async { + for _ in 0..NUM_INCREMENTS { + loop { + match storage.get(&path).await { + Ok(r) => { + let mode = PutMode::Update(UpdateVersion { + e_tag: r.meta.e_tag.clone(), + version: r.meta.version.clone(), + }); + + let b = r.bytes().await.unwrap(); + let v: usize = std::str::from_utf8(&b).unwrap().parse().unwrap(); + let new = (v + 1).to_string(); + + match storage.put_opts(&path, new.into(), mode.into()).await { + Ok(_) => break, + Err(Error::Precondition { .. }) => continue, + Err(e) => return Err(e), + } + } + Err(Error::NotFound { .. }) => { + let mode = PutMode::Create; + match storage.put_opts(&path, "1".into(), mode.into()).await { + Ok(_) => break, + Err(Error::AlreadyExists { .. 
}) => continue, + Err(e) => return Err(e), + } + } + Err(e) => return Err(e), + } + } + } + Ok(()) + }) + .collect(); + + while futures.next().await.transpose().unwrap().is_some() {} + let b = storage.get(&path).await.unwrap().bytes().await.unwrap(); + let v = std::str::from_utf8(&b).unwrap().parse::().unwrap(); + assert_eq!(v, NUM_WORKERS * NUM_INCREMENTS); + } + /// Returns a chunk of length `chunk_length` fn get_chunk(chunk_length: usize) -> Bytes { let mut data = vec![0_u8; chunk_length]; diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs index cd01a964dc3e..39cc605c4768 100644 --- a/object_store/src/limit.rs +++ b/object_store/src/limit.rs @@ -19,7 +19,7 @@ use crate::{ BoxStream, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, - ObjectStore, Path, PutResult, Result, StreamExt, + ObjectStore, Path, PutOptions, PutResult, Result, StreamExt, }; use async_trait::async_trait; use bytes::Bytes; @@ -77,6 +77,10 @@ impl ObjectStore for LimitStore { self.inner.put(location, bytes).await } + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.put_opts(location, bytes, opts).await + } async fn put_multipart( &self, location: &Path, diff --git a/object_store/src/local.rs b/object_store/src/local.rs index ce9aa4683499..919baf71b0a8 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -20,7 +20,7 @@ use crate::{ maybe_spawn_blocking, path::{absolute_path_to_url, Path}, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, ObjectMeta, ObjectStore, - PutResult, Result, + PutMode, PutOptions, PutResult, Result, }; use async_trait::async_trait; use bytes::Bytes; @@ -271,20 +271,44 @@ impl Config { #[async_trait] impl ObjectStore for LocalFileSystem { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + if matches!(opts.mode, PutMode::Update(_)) { + return Err(crate::Error::NotImplemented); + } + let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || { let (mut file, suffix) = new_staged_upload(&path)?; let staging_path = staged_upload_path(&path, &suffix); - file.write_all(&bytes) - .context(UnableToCopyDataToFileSnafu) - .and_then(|_| { - std::fs::rename(&staging_path, &path).context(UnableToRenameFileSnafu) - }) - .map_err(|e| { - let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup - e - })?; + + let err = match file.write_all(&bytes) { + Ok(_) => match opts.mode { + PutMode::Overwrite => match std::fs::rename(&staging_path, &path) { + Ok(_) => None, + Err(source) => Some(Error::UnableToRenameFile { source }), + }, + PutMode::Create => match std::fs::hard_link(&staging_path, &path) { + Ok(_) => { + let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup + None + } + Err(source) => match source.kind() { + ErrorKind::AlreadyExists => Some(Error::AlreadyExists { + path: path.to_str().unwrap().to_string(), + source, + }), + _ => Some(Error::UnableToRenameFile { source }), + }, + }, + PutMode::Update(_) => unreachable!(), + }, + Err(source) => Some(Error::UnableToCopyDataToFile { source }), + }; + + if let Some(err) = err { + let _ = std::fs::remove_file(&staging_path); // Attempt to cleanup + return Err(err.into()); + } let metadata = file.metadata().map_err(|e| Error::Metadata { source: e.into(), @@ -293,6 +317,7 @@ impl ObjectStore for LocalFileSystem { Ok(PutResult { 
e_tag: Some(get_etag(&metadata)), + version: None, }) }) .await @@ -1054,6 +1079,7 @@ mod tests { rename_and_copy(&integration).await; copy_if_not_exists(&integration).await; stream_get(&integration).await; + put_opts(&integration, false).await; } #[test] diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index 8b9522e48de8..9d79a798ad1f 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -17,7 +17,8 @@ //! An in-memory object store implementation use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutResult, Result, + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutMode, + PutOptions, PutResult, Result, UpdateVersion, }; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; @@ -52,6 +53,9 @@ enum Error { #[snafu(display("Object already exists at that location: {path}"))] AlreadyExists { path: String }, + + #[snafu(display("ETag required for conditional update"))] + MissingETag, } impl From for super::Error { @@ -110,9 +114,50 @@ impl Storage { let etag = self.next_etag; self.next_etag += 1; let entry = Entry::new(bytes, Utc::now(), etag); - self.map.insert(location.clone(), entry); + self.overwrite(location, entry); etag } + + fn overwrite(&mut self, location: &Path, entry: Entry) { + self.map.insert(location.clone(), entry); + } + + fn create(&mut self, location: &Path, entry: Entry) -> Result<()> { + use std::collections::btree_map; + match self.map.entry(location.clone()) { + btree_map::Entry::Occupied(_) => Err(Error::AlreadyExists { + path: location.to_string(), + } + .into()), + btree_map::Entry::Vacant(v) => { + v.insert(entry); + Ok(()) + } + } + } + + fn update(&mut self, location: &Path, v: UpdateVersion, entry: Entry) -> Result<()> { + match self.map.get_mut(location) { + // Return Precondition instead of NotFound for consistency with stores + None => Err(crate::Error::Precondition { + path: location.to_string(), + source: format!("Object at location {location} not found").into(), + }), + Some(e) => { + let existing = e.e_tag.to_string(); + let expected = v.e_tag.context(MissingETagSnafu)?; + if existing == expected { + *e = entry; + Ok(()) + } else { + Err(crate::Error::Precondition { + path: location.to_string(), + source: format!("{existing} does not match {expected}").into(), + }) + } + } + } + } } impl std::fmt::Display for InMemory { @@ -123,10 +168,21 @@ impl std::fmt::Display for InMemory { #[async_trait] impl ObjectStore for InMemory { - async fn put(&self, location: &Path, bytes: Bytes) -> Result { - let etag = self.storage.write().insert(location, bytes); + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + let mut storage = self.storage.write(); + let etag = storage.next_etag; + let entry = Entry::new(bytes, Utc::now(), etag); + + match opts.mode { + PutMode::Overwrite => storage.overwrite(location, entry), + PutMode::Create => storage.create(location, entry)?, + PutMode::Update(v) => storage.update(location, v, entry)?, + } + storage.next_etag += 1; + Ok(PutResult { e_tag: Some(etag.to_string()), + version: None, }) } @@ -425,7 +481,7 @@ impl AsyncWrite for InMemoryAppend { fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_flush(cx) } } @@ -449,6 +505,7 @@ mod tests { rename_and_copy(&integration).await; copy_if_not_exists(&integration).await; stream_get(&integration).await; + put_opts(&integration, true).await; } #[tokio::test] 
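A brief aside for reviewers: the in-memory `put_opts` support above enables the compare-and-swap pattern exercised by the new `put_opts` test in lib.rs. Below is a minimal, hedged sketch of that pattern as a user would write it; it assumes the in-memory store, a `tokio` runtime, and an illustrative `counter` object name, none of which are prescribed by this patch.

```rust
use bytes::Bytes;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{Error, ObjectStore, PutMode, UpdateVersion};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    let store = InMemory::new();
    let path = Path::from("counter"); // illustrative object name

    // Create the object, failing with Error::AlreadyExists if it is already present
    store
        .put_opts(&path, Bytes::from("0"), PutMode::Create.into())
        .await?;

    // Optimistic compare-and-swap: re-read and retry whenever another writer wins
    loop {
        let current = store.get(&path).await?;
        let version = UpdateVersion {
            e_tag: current.meta.e_tag.clone(),
            version: current.meta.version.clone(),
        };
        let value: u64 = std::str::from_utf8(&current.bytes().await?)
            .unwrap()
            .parse()
            .unwrap();

        let new = Bytes::from((value + 1).to_string());
        match store.put_opts(&path, new, PutMode::Update(version).into()).await {
            Ok(_) => break,                              // update applied atomically
            Err(Error::Precondition { .. }) => continue, // lost the race, try again
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
```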
diff --git a/object_store/src/prefix.rs b/object_store/src/prefix.rs index b5bff8b12dd7..68101307fbdf 100644 --- a/object_store/src/prefix.rs +++ b/object_store/src/prefix.rs @@ -23,7 +23,8 @@ use tokio::io::AsyncWrite; use crate::path::Path; use crate::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, Result, + GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutOptions, PutResult, + Result, }; #[doc(hidden)] @@ -85,6 +86,11 @@ impl ObjectStore for PrefixStore { self.inner.put(&full_path, bytes).await } + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + let full_path = self.full_path(location); + self.inner.put_opts(&full_path, bytes, opts).await + } + async fn put_multipart( &self, location: &Path, diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index c5521256b8a6..dcd2c04bcf05 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -21,7 +21,8 @@ use std::ops::Range; use std::{convert::TryInto, sync::Arc}; use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutResult, Result, + path::Path, GetResult, GetResultPayload, ListResult, ObjectMeta, ObjectStore, PutOptions, + PutResult, Result, }; use crate::{GetOptions, MultipartId}; use async_trait::async_trait; @@ -149,10 +150,14 @@ impl std::fmt::Display for ThrottledStore { impl ObjectStore for ThrottledStore { async fn put(&self, location: &Path, bytes: Bytes) -> Result { sleep(self.config().wait_put_per_call).await; - self.inner.put(location, bytes).await } + async fn put_opts(&self, location: &Path, bytes: Bytes, opts: PutOptions) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put_opts(location, bytes, opts).await + } + async fn put_multipart( &self, _location: &Path, diff --git a/object_store/tests/get_range_file.rs b/object_store/tests/get_range_file.rs index 3fa1cc7104b3..85231a5a5b9b 100644 --- a/object_store/tests/get_range_file.rs +++ b/object_store/tests/get_range_file.rs @@ -22,9 +22,7 @@ use bytes::Bytes; use futures::stream::BoxStream; use object_store::local::LocalFileSystem; use object_store::path::Path; -use object_store::{ - GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, PutResult, -}; +use object_store::*; use std::fmt::Formatter; use tempfile::tempdir; use tokio::io::AsyncWrite; @@ -40,50 +38,42 @@ impl std::fmt::Display for MyStore { #[async_trait] impl ObjectStore for MyStore { - async fn put(&self, path: &Path, data: Bytes) -> object_store::Result { - self.0.put(path, data).await + async fn put_opts(&self, path: &Path, data: Bytes, opts: PutOptions) -> Result { + self.0.put_opts(path, data, opts).await } async fn put_multipart( &self, _: &Path, - ) -> object_store::Result<(MultipartId, Box)> { + ) -> Result<(MultipartId, Box)> { todo!() } - async fn abort_multipart(&self, _: &Path, _: &MultipartId) -> object_store::Result<()> { + async fn abort_multipart(&self, _: &Path, _: &MultipartId) -> Result<()> { todo!() } - async fn get_opts( - &self, - location: &Path, - options: GetOptions, - ) -> object_store::Result { + async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { self.0.get_opts(location, options).await } - async fn head(&self, _: &Path) -> object_store::Result { - todo!() - } - - async fn delete(&self, _: &Path) -> object_store::Result<()> { + async fn delete(&self, _: &Path) -> Result<()> { todo!() } - fn list(&self, _: Option<&Path>) -> BoxStream<'_, 
object_store::Result> { + fn list(&self, _: Option<&Path>) -> BoxStream<'_, Result> { todo!() } - async fn list_with_delimiter(&self, _: Option<&Path>) -> object_store::Result { + async fn list_with_delimiter(&self, _: Option<&Path>) -> Result { todo!() } - async fn copy(&self, _: &Path, _: &Path) -> object_store::Result<()> { + async fn copy(&self, _: &Path, _: &Path) -> Result<()> { todo!() } - async fn copy_if_not_exists(&self, _: &Path, _: &Path) -> object_store::Result<()> { + async fn copy_if_not_exists(&self, _: &Path, _: &Path) -> Result<()> { todo!() } }
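For completeness, a hedged end-to-end sketch of opting into the new S3 conditional put support via the builder API added in this patch. The bucket name is a placeholder, and credentials, region, and a `tokio` runtime are assumed to come from the environment rather than being part of this change; equivalently, the `aws_conditional_put` config key can be set to `etag`.

```rust
use bytes::Bytes;
use object_store::aws::{AmazonS3Builder, S3ConditionalPut};
use object_store::path::Path;
use object_store::{Error, ObjectStore, PutMode};

#[tokio::main]
async fn main() -> object_store::Result<()> {
    // Opt in to HTTP-precondition based conditional puts. "example-bucket" is a
    // placeholder; region and credentials are read from the environment.
    let store = AmazonS3Builder::from_env()
        .with_bucket_name("example-bucket")
        .with_conditional_put(S3ConditionalPut::ETagMatch)
        .build()?;

    // PutMode::Create is translated to `If-None-Match: *`, so a concurrent
    // writer observes Error::AlreadyExists instead of silently overwriting
    let path = Path::from("locks/leader");
    match store
        .put_opts(&path, Bytes::from("node-1"), PutMode::Create.into())
        .await
    {
        Ok(_) => println!("lock acquired"),
        Err(Error::AlreadyExists { .. }) => println!("lock already held"),
        Err(e) => return Err(e),
    }
    Ok(())
}
```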