Skip to content

Commit

Permalink
Implement a lazy multipart writer
Browse files Browse the repository at this point in the history
Signed-off-by: 🐼 Samrose Ahmed 🐼 <[email protected]>
  • Loading branch information
Samrose-Ahmed committed Dec 16, 2023
1 parent 802ed42 commit 372bcbd
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 1 deletion.
32 changes: 32 additions & 0 deletions object_store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2115,6 +2115,38 @@ mod tests {
assert_eq!(meta.size, chunk_size * 2);
}

pub(crate) async fn multipart_lazy(storage: Arc<DynObjectStore>) {
let path = Path::from("test_multipart_lazy");
let chunk_size = 5 * 1024 * 1024;

let chunks = get_chunks(chunk_size, 1);
let mut w = crate::multipart::put_multipart_lazy(
Arc::clone(&storage),
path.clone(),
10 * 1024 * 1024,
);
for chunk in chunks {
w.write_all(&chunk).await.unwrap();
}
w.shutdown().await.unwrap();

let meta = storage.head(&path).await.unwrap();
assert_eq!(meta.size, chunk_size);

let mut w = crate::multipart::put_multipart_lazy(
Arc::clone(&storage),
path.clone(),
10 * 1024 * 1024,
);
let chunks = get_chunks(chunk_size, 4);
for chunk in chunks {
w.write_all(&chunk).await.unwrap();
}
w.shutdown().await.unwrap();
let meta = storage.head(&path).await.unwrap();
assert_eq!(meta.size, chunk_size * 4);
}

#[cfg(any(feature = "aws", feature = "azure"))]
pub(crate) async fn tagging<F, Fut>(storage: &dyn ObjectStore, validate: bool, get_tags: F)
where
Expand Down
8 changes: 8 additions & 0 deletions object_store/src/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,14 @@ mod tests {
);
}

#[tokio::test]
async fn test_multipart_lazy() {
let root = TempDir::new().unwrap();
let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap();
let store = Arc::new(integration);
crate::tests::multipart_lazy(store).await;
}

#[tokio::test]
async fn filesystem_filename_with_percent() {
let temp_dir = TempDir::new().unwrap();
Expand Down
215 changes: 214 additions & 1 deletion object_store/src/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
use async_trait::async_trait;
use bytes::Bytes;
use futures::{stream::FuturesUnordered, Future, StreamExt};
use std::task::ready;
use std::{io, pin::Pin, sync::Arc, task::Poll};
use tokio::io::AsyncWrite;

use crate::path::Path;
use crate::{MultipartId, PutResult, Result};
use crate::{MultipartId, ObjectStore, PutResult, Result};

type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;

Expand Down Expand Up @@ -316,3 +317,215 @@ pub trait MultiPartStore: Send + Sync + 'static {
/// Aborts a multipart upload
async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()>;
}

/// Create a lazy multipart writer for a given [`ObjectStore`] and [`Path`].
///
/// A multipart upload using `ObjectStore::put_multipart` will only be created when the size exceeds `multipart_threshold`,
/// otherwise a direct PUT will be performed on shutdown.
pub fn put_multipart_lazy(
store: Arc<dyn ObjectStore>,
path: Path,
multipart_threshold: usize,
) -> Box<dyn AsyncWrite + Send + Unpin> {
Box::new(LazyWriteMultiPart::new(store, path, multipart_threshold))
}

enum LazyWriteState {
/// Buffering data, not yet reached multipart threshold
Buffer(Vec<u8>),
/// Writer shutdown, putting data in progress
Put(BoxedTryFuture<()>),
/// Multipart threshold reached, creating a new multipart upload
CreateMultipart(BoxedTryFuture<Box<dyn AsyncWrite + Send + Unpin>>, Vec<u8>),
/// Writing the buffered data from before creation of multipart upload
FlushingInitialWrite(Option<Box<dyn AsyncWrite + Send + Unpin>>, Vec<u8>, usize),
/// Delegate to underlying multipart writer
AsyncWrite(Box<dyn AsyncWrite + Send + Unpin>),
}

/// Wrapper around a [`ObjectStore`] and [`Path`] that implements [`AsyncWrite`]
struct LazyWriteMultiPart {
store: Arc<dyn ObjectStore>,
path: Path,
multipart_threshold: usize,
state: LazyWriteState,
}

impl LazyWriteMultiPart {
pub fn new(store: Arc<dyn ObjectStore>, path: Path, multipart_threshold: usize) -> Self {
Self {
store,
path,
multipart_threshold,
state: LazyWriteState::Buffer(Vec::new()),
}
}

fn poll_create_multipart(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
match &mut self.state {
LazyWriteState::CreateMultipart(fut, buffer) => {
let writer = ready!(Pin::new(fut).poll(cx))?;
if buffer.is_empty() {
self.state = LazyWriteState::AsyncWrite(writer);
} else {
let new_buffer = std::mem::take(buffer);
self.state = LazyWriteState::FlushingInitialWrite(Some(writer), new_buffer, 0);
}
Poll::Ready(Ok(()))
}
_ => unreachable!(),
}
}

fn do_inner_flush(
cx: &mut std::task::Context<'_>,
writer: &mut Box<dyn AsyncWrite + Send + Unpin>,
buffer: &mut Vec<u8>,
flush_offset: &mut usize,
write_len: usize,
) -> Poll<Result<usize, io::Error>> {
let end = std::cmp::min(*flush_offset + write_len, buffer.len());
let n = ready!(Pin::new(writer).poll_write(cx, &buffer[*flush_offset..end]))?;
*flush_offset += n;
Poll::Ready(Ok(n))
}
}

impl AsyncWrite for LazyWriteMultiPart {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
let multipart_threshold = self.multipart_threshold;
let store = Arc::clone(&self.store);
let path = self.path.clone();

let mut wrote = 0;
loop {
match &mut self.state {
LazyWriteState::Buffer(buffer) => {
let buf_len = buf.len();
let new_len = buffer.len() + buf_len;

if new_len > multipart_threshold {
let new_buffer = std::mem::take(buffer);
let store = Arc::clone(&store);
let path = path.clone();
let create_fut = Box::pin(async move {
let (_, multipart_writer) = store.put_multipart(&path).await?;
Ok(multipart_writer)
});
self.state = LazyWriteState::CreateMultipart(create_fut, new_buffer);
} else {
buffer.extend_from_slice(buf);
return Poll::Ready(Ok(buf_len));
}
}
LazyWriteState::CreateMultipart(_, _) => {
ready!(self.as_mut().poll_create_multipart(cx))?;
}
LazyWriteState::FlushingInitialWrite(writer, buffer, flush_offset) => {
let n = ready!(Self::do_inner_flush(
cx,
writer.as_mut().unwrap(),
buffer,
flush_offset,
buf.len()
))?;

if *flush_offset == buffer.len() {
wrote += n;
self.state = LazyWriteState::AsyncWrite(writer.take().unwrap());
} else {
buffer.extend_from_slice(buf);
return Poll::Ready(Ok(n));
}
}
LazyWriteState::AsyncWrite(writer) => {
return Pin::new(writer).poll_write(cx, buf).map_ok(|n| n + wrote)
}
LazyWriteState::Put(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Cannot write after shutdown.",
))
.into()
}
}
}
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
loop {
match &mut self.state {
LazyWriteState::CreateMultipart(_, _) => {
ready!(self.as_mut().poll_create_multipart(cx))?;
}
LazyWriteState::FlushingInitialWrite(writer, buffer, flush_offset) => {
ready!(Self::do_inner_flush(
cx,
writer.as_mut().unwrap(),
buffer,
flush_offset,
buffer.len(),
)?);
if *flush_offset == buffer.len() {
self.state = LazyWriteState::AsyncWrite(writer.take().unwrap());
} else {
return Poll::Pending;
}
}
LazyWriteState::AsyncWrite(writer) => return Pin::new(writer).poll_flush(cx),
LazyWriteState::Buffer(_) => return Poll::Ready(Ok(())),
LazyWriteState::Put(_) => return Poll::Ready(Ok(())),
}
}
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
ready!(self.as_mut().poll_flush(cx))?;
let store = Arc::clone(&self.store);
let path = self.path.clone();

loop {
match &mut self.state {
LazyWriteState::Buffer(buffer) => {
let store = Arc::clone(&store);
let path = path.clone();
let buffer = std::mem::take(buffer);
let put_task = Box::pin(async move {
store.put(&path, buffer.into()).await?;
Ok(())
});
self.state = LazyWriteState::Put(put_task);
}
LazyWriteState::AsyncWrite(writer) => return Pin::new(writer).poll_shutdown(cx),
LazyWriteState::Put(fut) => return Pin::new(fut).poll(cx),
// handled by flush
_ => {
unreachable!();
}
}
}
}
}

impl std::fmt::Debug for LazyWriteMultiPart {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LazyWriteMultiPart")
.field("store", &self.store)
.field("path", &self.path)
.field("multipart_threshold", &self.multipart_threshold)
.finish()
}
}

0 comments on commit 372bcbd

Please sign in to comment.