From e513bd91daa72c65e5d7b20d577a1e003016c871 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 26 Feb 2024 22:29:11 +1300 Subject: [PATCH 1/2] Add BufWriter --- object_store/src/buffered.rs | 163 ++++++++++++++++++++++++++++++++++- 1 file changed, 161 insertions(+), 2 deletions(-) diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index 3a1354f4f20a..02252581b97a 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -18,7 +18,7 @@ //! Utilities for performing tokio-style buffered IO use crate::path::Path; -use crate::{ObjectMeta, ObjectStore}; +use crate::{MultipartId, ObjectMeta, ObjectStore}; use bytes::Bytes; use futures::future::{BoxFuture, FutureExt}; use futures::ready; @@ -27,7 +27,7 @@ use std::io::{Error, ErrorKind, SeekFrom}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, ReadBuf}; +use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, AsyncWriteExt, ReadBuf}; /// The default buffer size used by [`BufReader`] pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024; @@ -205,6 +205,138 @@ impl AsyncBufRead for BufReader { } } +/// An async buffered writer compatible with the tokio IO traits +/// +/// Up to `capacity` bytes will be buffered in memory, and flushed on shutdown +/// using [`ObjectStore::put`]. If `capacity` is exceeded, data will instead be +/// streamed using [`ObjectStore::put_multipart`] +pub struct BufWriter { + capacity: usize, + state: BufWriterState, + multipart_id: Option, + store: Arc, +} + +impl std::fmt::Debug for BufWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufWriter") + .field("capacity", &self.capacity) + .field("multipart_id", &self.multipart_id) + .finish() + } +} + +type MultipartResult = (MultipartId, Box); + +enum BufWriterState { + /// Buffer up to capacity bytes + Buffer(Path, Vec), + /// [`ObjectStore::put_multipart`] + Prepare(BoxFuture<'static, std::io::Result>), + /// Write to a multipart upload + Write(Box), + /// [`ObjectStore::put`] + Put(BoxFuture<'static, std::io::Result<()>>), +} + +impl BufWriter { + /// Create a new [`BufWriter`] from the provided [`ObjectStore`] and [`Path`] + pub fn new(store: Arc, path: Path) -> Self { + Self::with_capacity(store, path, 10 * 1024 * 1024) + } + + /// Create a new [`BufWriter`] from the provided [`ObjectStore`], [`Path`] and `capacity` + pub fn with_capacity(store: Arc, path: Path, capacity: usize) -> Self { + Self { + capacity, + store, + state: BufWriterState::Buffer(path, Vec::with_capacity(1024)), + multipart_id: None, + } + } + + /// Returns the [`MultipartId`] if multipart upload + pub fn multipart_id(&self) -> Option<&MultipartId> { + self.multipart_id.as_ref() + } +} + +impl AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let cap = self.capacity; + loop { + return match &mut self.state { + BufWriterState::Write(write) => Pin::new(write).poll_write(cx, buf), + BufWriterState::Put(_) => panic!("Already shut down"), + BufWriterState::Prepare(f) => { + let (id, w) = ready!(f.poll_unpin(cx)?); + self.state = BufWriterState::Write(w); + self.multipart_id = Some(id); + continue; + } + BufWriterState::Buffer(path, b) => { + if b.len().saturating_add(buf.len()) >= cap { + let buffer = std::mem::take(b); + let path = std::mem::take(path); + let store = Arc::clone(&self.store); + self.state = BufWriterState::Prepare(Box::pin(async move { + let (id, mut writer) = store.put_multipart(&path).await?; + writer.write_all(&buffer).await?; + Ok((id, writer)) + })); + continue; + } + b.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + }; + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + return match &mut self.state { + BufWriterState::Buffer(_, _) => Poll::Ready(Ok(())), + BufWriterState::Write(write) => Pin::new(write).poll_flush(cx), + BufWriterState::Put(_) => panic!("Already shut down"), + BufWriterState::Prepare(f) => { + let (id, w) = ready!(f.poll_unpin(cx)?); + self.state = BufWriterState::Write(w); + self.multipart_id = Some(id); + continue; + } + }; + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match &mut self.state { + BufWriterState::Prepare(f) => { + let (id, w) = ready!(f.poll_unpin(cx)?); + self.state = BufWriterState::Write(w); + self.multipart_id = Some(id); + } + BufWriterState::Buffer(p, b) => { + let buf = std::mem::take(b); + let path = std::mem::take(p); + let store = Arc::clone(&self.store); + self.state = BufWriterState::Put(Box::pin(async move { + store.put(&path, buf.into()).await?; + Ok(()) + })); + } + BufWriterState::Put(f) => return f.poll_unpin(cx), + BufWriterState::Write(w) => return Pin::new(w).poll_shutdown(cx), + } + } + } +} + /// Port of standardised function as requires Rust 1.66 /// /// @@ -300,4 +432,31 @@ mod tests { assert!(buffer.is_empty()); } } + + #[tokio::test] + async fn test_buf_writer() { + let store = Arc::new(InMemory::new()) as Arc; + let path = Path::from("file.txt"); + + // Test put + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30); + writer.write_all(&[0; 20]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[0; 5]).await.unwrap(); + assert!(writer.multipart_id().is_none()); + writer.shutdown().await.unwrap(); + assert!(writer.multipart_id().is_none()); + assert_eq!(store.head(&path).await.unwrap().size, 25); + + // Test multipart + let mut writer = BufWriter::with_capacity(Arc::clone(&store), path.clone(), 30); + writer.write_all(&[0; 20]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[0; 20]).await.unwrap(); + assert!(writer.multipart_id().is_some()); + writer.shutdown().await.unwrap(); + assert!(writer.multipart_id().is_some()); + + assert_eq!(store.head(&path).await.unwrap().size, 40); + } } From 58eba00b96378a4ffccadb414185d1a9439c1832 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 27 Feb 2024 08:56:03 +1300 Subject: [PATCH 2/2] Review feedback --- object_store/src/buffered.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index 02252581b97a..fdefe599f79e 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -250,7 +250,7 @@ impl BufWriter { Self { capacity, store, - state: BufWriterState::Buffer(path, Vec::with_capacity(1024)), + state: BufWriterState::Buffer(path, Vec::new()), multipart_id: None, } }