diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index 3a841667ff97..c7820a9278c0 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -2115,6 +2115,41 @@ mod tests {
         assert_eq!(meta.size, chunk_size * 2);
     }
 
+    /// Exercises `put_multipart_lazy` on both sides of its 10 MiB threshold.
+    pub(crate) async fn multipart_lazy(storage: Arc<dyn ObjectStore>) {
+        let path = Path::from("test_multipart_lazy");
+        let chunk_size = 5 * 1024 * 1024;
+
+        // 5 MiB in total stays below the threshold, so shutdown performs a direct put
+        let chunks = get_chunks(chunk_size, 1);
+        let mut w = crate::multipart::put_multipart_lazy(
+            Arc::clone(&storage),
+            path.clone(),
+            10 * 1024 * 1024,
+        );
+        for chunk in chunks {
+            w.write_all(&chunk).await.unwrap();
+        }
+        w.shutdown().await.unwrap();
+
+        let meta = storage.head(&path).await.unwrap();
+        assert_eq!(meta.size, chunk_size);
+
+        // 20 MiB in total: the third chunk exceeds the threshold and starts a multipart upload
+        let mut w = crate::multipart::put_multipart_lazy(
+            Arc::clone(&storage),
+            path.clone(),
+            10 * 1024 * 1024,
+        );
+        let chunks = get_chunks(chunk_size, 4);
+        for chunk in chunks {
+            w.write_all(&chunk).await.unwrap();
+        }
+        w.shutdown().await.unwrap();
+        let meta = storage.head(&path).await.unwrap();
+        assert_eq!(meta.size, chunk_size * 4);
+    }
+
     #[cfg(any(feature = "aws", feature = "azure"))]
     pub(crate) async fn tagging<F, Fut>(storage: &dyn ObjectStore, validate: bool, get_tags: F)
     where
diff --git a/object_store/src/local.rs b/object_store/src/local.rs
index 71b96f058c79..e2653a659d15 100644
--- a/object_store/src/local.rs
+++ b/object_store/src/local.rs
@@ -1432,6 +1432,14 @@ mod tests {
         );
     }
 
+    #[tokio::test]
+    async fn test_multipart_lazy() {
+        let root = TempDir::new().unwrap();
+        let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap();
+        let store = Arc::new(integration);
+        crate::tests::multipart_lazy(store).await;
+    }
+
     #[tokio::test]
     async fn filesystem_filename_with_percent() {
         let temp_dir = TempDir::new().unwrap();
diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs
index 1dcd5a6f4960..d3d95a324e63 100644
--- a/object_store/src/multipart.rs
+++ b/object_store/src/multipart.rs
@@ -24,11 +24,12 @@
 use async_trait::async_trait;
 use bytes::Bytes;
 use futures::{stream::FuturesUnordered, Future, StreamExt};
+use std::task::ready;
 use std::{io, pin::Pin, sync::Arc, task::Poll};
 use tokio::io::AsyncWrite;
 
 use crate::path::Path;
-use crate::{MultipartId, PutResult, Result};
+use crate::{MultipartId, ObjectStore, PutResult, Result};
 
 type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;
 
@@ -316,3 +317,211 @@ pub trait MultiPartStore: Send + Sync + 'static {
     /// Aborts a multipart upload
     async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()>;
 }
+
+/// Create a lazy multipart writer for a given [`ObjectStore`] and [`Path`].
+///
+/// Writes are buffered in memory until more than `multipart_threshold` bytes have been
+/// written; only then is a multipart upload created using `ObjectStore::put_multipart`.
+/// If the threshold is never exceeded, a single direct PUT is performed on shutdown.
+pub fn put_multipart_lazy(
+    store: Arc<dyn ObjectStore>,
+    path: Path,
+    multipart_threshold: usize,
+) -> Box<dyn AsyncWrite + Unpin + Send> {
+    Box::new(LazyWriteMultiPart::new(store, path, multipart_threshold))
+}
+
+enum LazyWriteState {
+    /// Buffering data, not yet reached multipart threshold
+    Buffer(Vec<u8>),
+    /// Writer shutdown, putting data in progress
+    Put(BoxedTryFuture<()>),
+    /// Multipart threshold reached, creating a new multipart upload
+    CreateMultipart(BoxedTryFuture<Box<dyn AsyncWrite + Unpin + Send>>, Vec<u8>),
+    /// Writing the buffered data from before creation of multipart upload;
+    /// the `usize` counts the buffered bytes already accepted by the writer
+    FlushingInitialWrite(Option<Box<dyn AsyncWrite + Unpin + Send>>, Vec<u8>, usize),
+    /// Delegate to underlying multipart writer
+    AsyncWrite(Box<dyn AsyncWrite + Unpin + Send>),
+}
+
+/// Wrapper around a [`ObjectStore`] and [`Path`] that implements [`AsyncWrite`]
+struct LazyWriteMultiPart {
+    store: Arc<dyn ObjectStore>,
+    path: Path,
+    multipart_threshold: usize,
+    state: LazyWriteState,
+}
+
+impl LazyWriteMultiPart {
+    pub fn new(store: Arc<dyn ObjectStore>, path: Path, multipart_threshold: usize) -> Self {
+        Self {
+            store,
+            path,
+            multipart_threshold,
+            state: LazyWriteState::Buffer(Vec::new()),
+        }
+    }
+
+    /// Drives the pending `put_multipart` call to completion.
+    ///
+    /// Must only be called in the `CreateMultipart` state. On success the state becomes
+    /// `FlushingInitialWrite` if data was buffered, or `AsyncWrite` otherwise.
+    fn poll_create_multipart(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        match &mut self.state {
+            LazyWriteState::CreateMultipart(fut, buffer) => {
+                let writer = ready!(Pin::new(fut).poll(cx))?;
+                if buffer.is_empty() {
+                    self.state = LazyWriteState::AsyncWrite(writer);
+                } else {
+                    let new_buffer = std::mem::take(buffer);
+                    self.state = LazyWriteState::FlushingInitialWrite(Some(writer), new_buffer, 0);
+                }
+                Poll::Ready(Ok(()))
+            }
+            _ => unreachable!(),
+        }
+    }
+
+    /// Hands the rest of the pre-multipart buffer to the multipart writer.
+    ///
+    /// Must only be called in the `FlushingInitialWrite` state. Switches to `AsyncWrite`
+    /// and returns `Ready(Ok(()))` once every buffered byte has been accepted. `Pending`
+    /// is only returned when the inner writer returned it, so a wake-up is registered.
+    fn poll_flush_initial_write(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        match &mut self.state {
+            LazyWriteState::FlushingInitialWrite(writer, buffer, flush_offset) => {
+                let inner = writer.as_mut().expect("writer is set while flushing");
+                while *flush_offset < buffer.len() {
+                    let remaining = &buffer[*flush_offset..];
+                    let n = ready!(Pin::new(&mut *inner).poll_write(cx, remaining))?;
+                    if n == 0 {
+                        // A writer that accepts nothing would make this loop spin forever.
+                        return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
+                    }
+                    *flush_offset += n;
+                }
+                let writer = writer.take().expect("writer is set while flushing");
+                self.state = LazyWriteState::AsyncWrite(writer);
+                Poll::Ready(Ok(()))
+            }
+            _ => unreachable!(),
+        }
+    }
+}
+
+impl AsyncWrite for LazyWriteMultiPart {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        // `Self` is `Unpin`; plain field access keeps the borrows of `state`, `store`
+        // and `path` disjoint, so nothing has to be cloned on every call.
+        let this = self.get_mut();
+
+        loop {
+            match &mut this.state {
+                LazyWriteState::Buffer(buffer) => {
+                    if buffer.len() + buf.len() <= this.multipart_threshold {
+                        buffer.extend_from_slice(buf);
+                        return Poll::Ready(Ok(buf.len()));
+                    }
+
+                    // `buf` is not consumed here: it is written once the multipart
+                    // writer exists and the earlier buffered data has been flushed.
+                    let new_buffer = std::mem::take(buffer);
+                    let store = Arc::clone(&this.store);
+                    let path = this.path.clone();
+                    let create_fut: BoxedTryFuture<Box<dyn AsyncWrite + Unpin + Send>> =
+                        Box::pin(async move {
+                            let (_, multipart_writer) = store.put_multipart(&path).await?;
+                            Ok(multipart_writer)
+                        });
+                    this.state = LazyWriteState::CreateMultipart(create_fut, new_buffer);
+                }
+                LazyWriteState::CreateMultipart(_, _) => {
+                    ready!(this.poll_create_multipart(cx))?;
+                }
+                LazyWriteState::FlushingInitialWrite(_, _, _) => {
+                    // Bytes flushed here were already acknowledged by earlier calls, so
+                    // they must not be counted towards the return value for `buf`.
+                    ready!(this.poll_flush_initial_write(cx))?;
+                }
+                LazyWriteState::AsyncWrite(writer) => return Pin::new(writer).poll_write(cx, buf),
+                LazyWriteState::Put(_) => {
+                    return Poll::Ready(Err(io::Error::new(
+                        io::ErrorKind::InvalidInput,
+                        "Cannot write after shutdown.",
+                    )));
+                }
+            }
+        }
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        let this = self.get_mut();
+
+        loop {
+            match &mut this.state {
+                LazyWriteState::CreateMultipart(_, _) => {
+                    ready!(this.poll_create_multipart(cx))?;
+                }
+                LazyWriteState::FlushingInitialWrite(_, _, _) => {
+                    ready!(this.poll_flush_initial_write(cx))?;
+                }
+                LazyWriteState::AsyncWrite(writer) => return Pin::new(writer).poll_flush(cx),
+                LazyWriteState::Buffer(_) => return Poll::Ready(Ok(())),
+                LazyWriteState::Put(_) => return Poll::Ready(Ok(())),
+            }
+        }
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        ready!(self.as_mut().poll_flush(cx))?;
+        let this = self.get_mut();
+
+        loop {
+            match &mut this.state {
+                LazyWriteState::Buffer(buffer) => {
+                    let store = Arc::clone(&this.store);
+                    let path = this.path.clone();
+                    let buffer = std::mem::take(buffer);
+                    let put_task: BoxedTryFuture<()> = Box::pin(async move {
+                        store.put(&path, buffer.into()).await?;
+                        Ok(())
+                    });
+                    this.state = LazyWriteState::Put(put_task);
+                }
+                LazyWriteState::AsyncWrite(writer) => return Pin::new(writer).poll_shutdown(cx),
+                LazyWriteState::Put(fut) => return Pin::new(fut).poll(cx),
+                // handled by flush
+                _ => {
+                    unreachable!();
+                }
+            }
+        }
+    }
+}
+
+impl std::fmt::Debug for LazyWriteMultiPart {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("LazyWriteMultiPart")
+            .field("store", &self.store)
+            .field("path", &self.path)
+            .field("multipart_threshold", &self.multipart_threshold)
+            .finish()
+    }
+}