From 9f36c883459405ecd9a5f4fdfa9a3317ab52302c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:14:35 +0000 Subject: [PATCH] Implement MultipartStore for ThrottledStore (#5533) * Implement MultipartStore for ThrottledStore Limit concurrency in BufWriter Tweak WriteMultipart * Fix MSRV * Format --- object_store/src/buffered.rs | 14 +++++++ object_store/src/throttle.rs | 78 ++++++++++++++++++++++++++++++++---- object_store/src/upload.rs | 76 +++++++++++++++++++++++++++++------ 3 files changed, 148 insertions(+), 20 deletions(-) diff --git a/object_store/src/buffered.rs b/object_store/src/buffered.rs index 39f8eafbef7e..de6d4eb1bb9c 100644 --- a/object_store/src/buffered.rs +++ b/object_store/src/buffered.rs @@ -216,6 +216,7 @@ impl AsyncBufRead for BufReader { /// streamed using [`ObjectStore::put_multipart`] pub struct BufWriter { capacity: usize, + max_concurrency: usize, state: BufWriterState, store: Arc, } @@ -250,10 +251,21 @@ impl BufWriter { Self { capacity, store, + max_concurrency: 8, state: BufWriterState::Buffer(path, Vec::new()), } } + /// Override the maximum number of in-flight requests for this writer + /// + /// Defaults to 8 + pub fn with_max_concurrency(self, max_concurrency: usize) -> Self { + Self { + max_concurrency, + ..self + } + } + /// Abort this writer, cleaning up any partially uploaded state /// /// # Panic @@ -275,9 +287,11 @@ impl AsyncWrite for BufWriter { buf: &[u8], ) -> Poll> { let cap = self.capacity; + let max_concurrency = self.max_concurrency; loop { return match &mut self.state { BufWriterState::Write(Some(write)) => { + ready!(write.poll_for_capacity(cx, max_concurrency))?; write.write(buf); Poll::Ready(Ok(buf.len())) } diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 5ca1eedbf739..65fac5922f69 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -20,11 +20,12 @@ use parking_lot::Mutex; use std::ops::Range; use std::{convert::TryInto, sync::Arc}; -use crate::GetOptions; +use crate::multipart::{MultipartStore, PartId}; use crate::{ - path::Path, GetResult, GetResultPayload, ListResult, MultipartUpload, ObjectMeta, ObjectStore, - PutOptions, PutResult, Result, + path::Path, GetResult, GetResultPayload, ListResult, MultipartId, MultipartUpload, ObjectMeta, + ObjectStore, PutOptions, PutResult, Result, }; +use crate::{GetOptions, UploadPart}; use async_trait::async_trait; use bytes::Bytes; use futures::{stream::BoxStream, FutureExt, StreamExt}; @@ -110,12 +111,12 @@ async fn sleep(duration: Duration) { /// **Note that the behavior of the wrapper is deterministic and might not reflect real-world /// conditions!** #[derive(Debug)] -pub struct ThrottledStore { +pub struct ThrottledStore { inner: T, config: Arc>, } -impl ThrottledStore { +impl ThrottledStore { /// Create new wrapper with zero waiting times. pub fn new(inner: T, config: ThrottleConfig) -> Self { Self { @@ -157,8 +158,12 @@ impl ObjectStore for ThrottledStore { self.inner.put_opts(location, bytes, opts).await } - async fn put_multipart(&self, _location: &Path) -> Result> { - Err(super::Error::NotImplemented) + async fn put_multipart(&self, location: &Path) -> Result> { + let upload = self.inner.put_multipart(location).await?; + Ok(Box::new(ThrottledUpload { + upload, + sleep: self.config().wait_put_per_call, + })) } async fn get(&self, location: &Path) -> Result { @@ -316,6 +321,63 @@ where .boxed() } +#[async_trait] +impl MultipartStore for ThrottledStore { + async fn create_multipart(&self, path: &Path) -> Result { + self.inner.create_multipart(path).await + } + + async fn put_part( + &self, + path: &Path, + id: &MultipartId, + part_idx: usize, + data: Bytes, + ) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.put_part(path, id, part_idx, data).await + } + + async fn complete_multipart( + &self, + path: &Path, + id: &MultipartId, + parts: Vec, + ) -> Result { + self.inner.complete_multipart(path, id, parts).await + } + + async fn abort_multipart(&self, path: &Path, id: &MultipartId) -> Result<()> { + self.inner.abort_multipart(path, id).await + } +} + +#[derive(Debug)] +struct ThrottledUpload { + upload: Box, + sleep: Duration, +} + +#[async_trait] +impl MultipartUpload for ThrottledUpload { + fn put_part(&mut self, data: Bytes) -> UploadPart { + let duration = self.sleep; + let put = self.upload.put_part(data); + Box::pin(async move { + sleep(duration).await; + put.await + }) + } + + async fn complete(&mut self) -> Result { + self.upload.complete().await + } + + async fn abort(&mut self) -> Result<()> { + self.upload.abort().await + } +} + #[cfg(test)] mod tests { use super::*; @@ -351,6 +413,8 @@ mod tests { list_with_delimiter(&store).await; rename_and_copy(&store).await; copy_if_not_exists(&store).await; + stream_get(&store).await; + multipart(&store, &store).await; } #[tokio::test] diff --git a/object_store/src/upload.rs b/object_store/src/upload.rs index 6f8bfa8a5f73..fe864e2821c9 100644 --- a/object_store/src/upload.rs +++ b/object_store/src/upload.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::{PutResult, Result}; +use std::task::{Context, Poll}; + use async_trait::async_trait; use bytes::Bytes; use futures::future::BoxFuture; +use futures::ready; use tokio::task::JoinSet; +use crate::{PutResult, Result}; + /// An upload part request pub type UploadPart = BoxFuture<'static, Result<()>>; @@ -110,31 +114,44 @@ pub struct WriteMultipart { impl WriteMultipart { /// Create a new [`WriteMultipart`] that will upload using 5MB chunks pub fn new(upload: Box) -> Self { - Self::new_with_capacity(upload, 5 * 1024 * 1024) + Self::new_with_chunk_size(upload, 5 * 1024 * 1024) } - /// Create a new [`WriteMultipart`] that will upload in fixed `capacity` sized chunks - pub fn new_with_capacity(upload: Box, capacity: usize) -> Self { + /// Create a new [`WriteMultipart`] that will upload in fixed `chunk_size` sized chunks + pub fn new_with_chunk_size(upload: Box, chunk_size: usize) -> Self { Self { upload, - buffer: Vec::with_capacity(capacity), + buffer: Vec::with_capacity(chunk_size), tasks: Default::default(), } } - /// Wait until there are `max_concurrency` or fewer requests in-flight - pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> { - while self.tasks.len() > max_concurrency { - self.tasks.join_next().await.unwrap()??; + /// Polls for there to be less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::wait_for_capacity`] for an async version of this function + pub fn poll_for_capacity( + &mut self, + cx: &mut Context<'_>, + max_concurrency: usize, + ) -> Poll> { + while !self.tasks.is_empty() && self.tasks.len() >= max_concurrency { + ready!(self.tasks.poll_join_next(cx)).unwrap()?? } - Ok(()) + Poll::Ready(Ok(())) + } + + /// Wait until there are less than `max_concurrency` [`UploadPart`] in progress + /// + /// See [`Self::poll_for_capacity`] for a [`Poll`] version of this function + pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> { + futures::future::poll_fn(|cx| self.poll_for_capacity(cx, max_concurrency)).await } /// Write data to this [`WriteMultipart`] /// - /// Note this method is synchronous (not `async`) and will immediately start new uploads - /// as soon as the internal `capacity` is hit, regardless of - /// how many outstanding uploads are already in progress. + /// Note this method is synchronous (not `async`) and will immediately + /// start new uploads as soon as the internal `chunk_size` is hit, + /// regardless of how many outstanding uploads are already in progress. /// /// Back pressure can optionally be applied to producers by calling /// [`Self::wait_for_capacity`] prior to calling this method @@ -173,3 +190,36 @@ impl WriteMultipart { self.upload.complete().await } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures::FutureExt; + + use crate::memory::InMemory; + use crate::path::Path; + use crate::throttle::{ThrottleConfig, ThrottledStore}; + use crate::ObjectStore; + + use super::*; + + #[tokio::test] + async fn test_concurrency() { + let config = ThrottleConfig { + wait_put_per_call: Duration::from_millis(1), + ..Default::default() + }; + + let path = Path::from("foo"); + let store = ThrottledStore::new(InMemory::new(), config); + let upload = store.put_multipart(&path).await.unwrap(); + let mut write = WriteMultipart::new_with_chunk_size(upload, 10); + + for _ in 0..20 { + write.write(&[0; 5]); + } + assert!(write.wait_for_capacity(10).now_or_never().is_none()); + write.wait_for_capacity(10).await.unwrap() + } +}