From 23db343b7e260e3ef768c6bd81d68dd7915be9d4 Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Fri, 5 Jan 2024 21:57:35 +0900 Subject: [PATCH] feat: implement concurrent `MultipartUploadWriter` (#3915) * feat: expose the concurrent field * feat: implement concurrent `MultipartUploadWriter` * chore: apply suggestions from CR * chore: apply suggestions from CR * chore: correct comments * fix: clear future queue while aborting --- core/src/raw/futures_util.rs | 9 + .../raw/oio/write/multipart_upload_write.rs | 213 ++++++++++-------- core/src/raw/ops.rs | 12 + core/src/services/b2/backend.rs | 3 +- core/src/services/cos/backend.rs | 2 +- core/src/services/obs/backend.rs | 2 +- core/src/services/oss/backend.rs | 2 +- core/src/services/s3/backend.rs | 3 +- core/src/services/upyun/backend.rs | 3 +- core/src/types/operator/operator_futures.rs | 12 + 10 files changed, 158 insertions(+), 103 deletions(-) diff --git a/core/src/raw/futures_util.rs b/core/src/raw/futures_util.rs index 08583fe6d8e4..d7a1168b6340 100644 --- a/core/src/raw/futures_util.rs +++ b/core/src/raw/futures_util.rs @@ -101,6 +101,15 @@ where } } + /// Drop all tasks. + pub fn clear(&mut self) { + match &mut self.tasks { + Tasks::Once(fut) => *fut = None, + Tasks::Small(tasks) => tasks.clear(), + Tasks::Large(tasks) => *tasks = FuturesOrdered::new(), + } + } + /// Return the length of current concurrent futures (both ongoing and ready). pub fn len(&self) -> usize { match &self.tasks { diff --git a/core/src/raw/oio/write/multipart_upload_write.rs b/core/src/raw/oio/write/multipart_upload_write.rs index a0f92aa98ec6..66105b4d4a4f 100644 --- a/core/src/raw/oio/write/multipart_upload_write.rs +++ b/core/src/raw/oio/write/multipart_upload_write.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::pin::Pin; use std::sync::Arc; use std::task::ready; use std::task::Context; use std::task::Poll; use async_trait::async_trait; +use futures::Future; +use futures::FutureExt; +use futures::StreamExt; use crate::raw::*; use crate::*; @@ -103,44 +107,76 @@ pub struct MultipartUploadPart { pub etag: String, } +struct WritePartFuture(BoxedFuture>); + +/// # Safety +/// +/// wasm32 is a special target that we only have one event-loop for this WritePartFuture. +unsafe impl Send for WritePartFuture {} + +/// # Safety +/// +/// We will only take `&mut Self` reference for WritePartFuture. +unsafe impl Sync for WritePartFuture {} + +impl Future for WritePartFuture { + type Output = Result; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.get_mut().0.poll_unpin(cx) + } +} + /// MultipartUploadWriter will implements [`Write`] based on multipart /// uploads. pub struct MultipartUploadWriter { - state: State, + state: State, + w: Arc, - cache: Option, upload_id: Option>, parts: Vec, + cache: Option, + futures: ConcurrentFutures, + next_part_number: usize, } -enum State { - Idle(Option), - Init(BoxedFuture<(W, Result)>), - Write(BoxedFuture<(W, Result)>), - Close(BoxedFuture<(W, Result<()>)>), - Abort(BoxedFuture<(W, Result<()>)>), +enum State { + Idle, + Init(BoxedFuture>), + Close(BoxedFuture>), + Abort(BoxedFuture>), } /// # Safety /// /// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} +unsafe impl Send for State {} /// # Safety /// /// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} +unsafe impl Sync for State {} impl MultipartUploadWriter { /// Create a new MultipartUploadWriter. 
- pub fn new(inner: W) -> Self { + pub fn new(inner: W, concurrent: usize) -> Self { Self { - state: State::Idle(Some(inner)), + state: State::Idle, - cache: None, + w: Arc::new(inner), upload_id: None, parts: Vec::new(), + cache: None, + futures: ConcurrentFutures::new(1.max(concurrent)), + next_part_number: 0, } } + + fn fill_cache(&mut self, bs: &dyn oio::WriteBuf) -> usize { + let size = bs.remaining(); + let bs = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size)); + assert!(self.cache.is_none()); + self.cache = Some(bs); + size + } } impl oio::Write for MultipartUploadWriter @@ -150,61 +186,49 @@ where fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll> { loop { match &mut self.state { - State::Idle(w) => { + State::Idle => { match self.upload_id.as_ref() { Some(upload_id) => { let upload_id = upload_id.clone(); - let part_number = self.parts.len(); - - let bs = self.cache.clone().expect("cache must be valid").clone(); - let w = w.take().expect("writer must be valid"); - self.state = State::Write(Box::pin(async move { - let size = bs.len(); - let part = w - .write_part( + if self.futures.has_remaining() { + let cache = self.cache.take().expect("pending write must exist"); + let part_number = self.next_part_number; + self.next_part_number += 1; + let w = self.w.clone(); + let size = cache.len(); + self.futures.push(WritePartFuture(Box::pin(async move { + w.write_part( &upload_id, part_number, size as u64, - AsyncBody::ChunkedBytes(bs), + AsyncBody::ChunkedBytes(cache), ) - .await; - - (w, part) - })); + .await + }))); + let size = self.fill_cache(bs); + return Poll::Ready(Ok(size)); + } else if let Some(part) = ready!(self.futures.poll_next_unpin(cx)) { + self.parts.push(part?); + } } None => { // Fill cache with the first write. 
if self.cache.is_none() { - let size = bs.remaining(); - let cb = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size)); - self.cache = Some(cb); + let size = self.fill_cache(bs); return Poll::Ready(Ok(size)); } - let w = w.take().expect("writer must be valid"); - self.state = State::Init(Box::pin(async move { - let upload_id = w.initiate_part().await; - (w, upload_id) - })); + let w = self.w.clone(); + self.state = + State::Init(Box::pin(async move { w.initiate_part().await })); } } } State::Init(fut) => { - let (w, upload_id) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); + let upload_id = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle; self.upload_id = Some(Arc::new(upload_id?)); } - State::Write(fut) => { - let (w, part) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - self.parts.push(part?); - - // Replace the cache when last write succeeded - let size = bs.remaining(); - let cb = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size)); - self.cache = Some(cb); - return Poll::Ready(Ok(size)); - } State::Close(_) => { unreachable!( "MultipartUploadWriter must not go into State::Close during poll_write" @@ -222,73 +246,71 @@ where fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - State::Idle(w) => { - let w = w.take().expect("writer must be valid"); + State::Idle => { match self.upload_id.clone() { Some(upload_id) => { - let parts = self.parts.clone(); - match self.cache.clone() { - Some(bs) => { - let upload_id = upload_id.clone(); - self.state = State::Write(Box::pin(async move { - let size = bs.len(); - let part = w - .write_part( + let w = self.w.clone(); + if self.futures.is_empty() && self.cache.is_none() { + let upload_id = upload_id.clone(); + let parts = self.parts.clone(); + self.state = State::Close(Box::pin(async move { + w.complete_part(&upload_id, &parts).await + })); + } else { + if self.futures.has_remaining() { + if let Some(cache) = self.cache.take() { + let upload_id = upload_id.clone(); + let part_number = self.next_part_number; + self.next_part_number += 1; + let size = cache.len(); + let w = self.w.clone(); + self.futures.push(WritePartFuture(Box::pin(async move { + w.write_part( &upload_id, - parts.len(), + part_number, size as u64, - AsyncBody::ChunkedBytes(bs), + AsyncBody::ChunkedBytes(cache), ) - .await; - (w, part) - })); + .await + }))); + } } - None => { - self.state = State::Close(Box::pin(async move { - let res = w.complete_part(&upload_id, &parts).await; - (w, res) - })); + while let Some(part) = ready!(self.futures.poll_next_unpin(cx)) { + self.parts.push(part?); } } } - None => match self.cache.clone() { - Some(bs) => { + None => match &self.cache { + Some(cache) => { + let w = self.w.clone(); + let bs = cache.clone(); self.state = State::Close(Box::pin(async move { let size = bs.len(); - let res = w - .write_once(size as u64, AsyncBody::ChunkedBytes(bs)) - .await; - (w, res) + w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await })); } None => { + let w = self.w.clone(); // Call write_once if there is no data in cache and no upload_id. self.state = State::Close(Box::pin(async move { - let res = w.write_once(0, AsyncBody::Empty).await; - (w, res) + w.write_once(0, AsyncBody::Empty).await })); } }, } } State::Close(fut) => { - let (w, res) = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); + let res = futures::ready!(fut.as_mut().poll(cx)); + self.state = State::Idle; // We should check res first before clean up cache. 
res?; - self.cache = None; + return Poll::Ready(Ok(())); } State::Init(_) => unreachable!( "MultipartUploadWriter must not go into State::Init during poll_close" ), - State::Write(fut) => { - let (w, part) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - self.parts.push(part?); - self.cache = None; - } State::Abort(_) => unreachable!( "MultipartUploadWriter must not go into State::Abort during poll_close" ), @@ -299,32 +321,29 @@ where fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - State::Idle(w) => { - let w = w.take().expect("writer must be valid"); + State::Idle => { + let w = self.w.clone(); match self.upload_id.clone() { Some(upload_id) => { - self.state = State::Abort(Box::pin(async move { - let res = w.abort_part(&upload_id).await; - (w, res) - })); + self.futures.clear(); + self.state = + State::Abort(Box::pin( + async move { w.abort_part(&upload_id).await }, + )); } None => { - self.cache = None; return Poll::Ready(Ok(())); } } } State::Abort(fut) => { - let (w, res) = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); + let res = futures::ready!(fut.as_mut().poll(cx)); + self.state = State::Idle; return Poll::Ready(res); } State::Init(_) => unreachable!( "MultipartUploadWriter must not go into State::Init during poll_abort" ), - State::Write(_) => unreachable!( - "MultipartUploadWriter must not go into State::Write during poll_abort" - ), State::Close(_) => unreachable!( "MultipartUploadWriter must not go into State::Close during poll_abort" ), diff --git a/core/src/raw/ops.rs b/core/src/raw/ops.rs index dbf43bfaa79c..955a5884210e 100644 --- a/core/src/raw/ops.rs +++ b/core/src/raw/ops.rs @@ -516,6 +516,7 @@ impl OpStat { pub struct OpWrite { append: bool, buffer: Option, + concurrent: usize, content_type: Option, content_disposition: Option, @@ -601,6 +602,17 @@ impl OpWrite { self.cache_control = Some(cache_control.to_string()); self } + + /// Get the concurrent. + pub fn concurrent(&self) -> usize { + self.concurrent + } + + /// Set the maximum concurrent write task amount. + pub fn with_concurrent(mut self, concurrent: usize) -> Self { + self.concurrent = concurrent; + self + } } /// Args for `copy` operation. 
diff --git a/core/src/services/b2/backend.rs b/core/src/services/b2/backend.rs index 2c5128bad659..a6f29055fddb 100644 --- a/core/src/services/b2/backend.rs +++ b/core/src/services/b2/backend.rs @@ -376,9 +376,10 @@ impl Accessor for B2Backend { } async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { + let concurrent = args.concurrent(); let writer = B2Writer::new(self.core.clone(), path, args); - let w = oio::MultipartUploadWriter::new(writer); + let w = oio::MultipartUploadWriter::new(writer, concurrent); Ok((RpWrite::default(), w)) } diff --git a/core/src/services/cos/backend.rs b/core/src/services/cos/backend.rs index 5c6d45885860..e0e58701fd68 100644 --- a/core/src/services/cos/backend.rs +++ b/core/src/services/cos/backend.rs @@ -340,7 +340,7 @@ impl Accessor for CosBackend { let w = if args.append() { CosWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - CosWriters::One(oio::MultipartUploadWriter::new(writer)) + CosWriters::One(oio::MultipartUploadWriter::new(writer, args.concurrent())) }; Ok((RpWrite::default(), w)) diff --git a/core/src/services/obs/backend.rs b/core/src/services/obs/backend.rs index d00fbe423a47..ad3b9c98fc89 100644 --- a/core/src/services/obs/backend.rs +++ b/core/src/services/obs/backend.rs @@ -350,7 +350,7 @@ impl Accessor for ObsBackend { let w = if args.append() { ObsWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - ObsWriters::One(oio::MultipartUploadWriter::new(writer)) + ObsWriters::One(oio::MultipartUploadWriter::new(writer, args.concurrent())) }; Ok((RpWrite::default(), w)) diff --git a/core/src/services/oss/backend.rs b/core/src/services/oss/backend.rs index 6e109e132175..79624a7dded8 100644 --- a/core/src/services/oss/backend.rs +++ b/core/src/services/oss/backend.rs @@ -492,7 +492,7 @@ impl Accessor for OssBackend { let w = if args.append() { OssWriters::Two(oio::AppendObjectWriter::new(writer)) } else { - OssWriters::One(oio::MultipartUploadWriter::new(writer)) + OssWriters::One(oio::MultipartUploadWriter::new(writer, args.concurrent())) }; Ok((RpWrite::default(), w)) diff --git a/core/src/services/s3/backend.rs b/core/src/services/s3/backend.rs index 94acf49c37de..2094fde0d93c 100644 --- a/core/src/services/s3/backend.rs +++ b/core/src/services/s3/backend.rs @@ -1085,9 +1085,10 @@ impl Accessor for S3Backend { } async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { + let concurrent = args.concurrent(); let writer = S3Writer::new(self.core.clone(), path, args); - let w = oio::MultipartUploadWriter::new(writer); + let w = oio::MultipartUploadWriter::new(writer, concurrent); Ok((RpWrite::default(), w)) } diff --git a/core/src/services/upyun/backend.rs b/core/src/services/upyun/backend.rs index 091977f6f59a..a13764b945cc 100644 --- a/core/src/services/upyun/backend.rs +++ b/core/src/services/upyun/backend.rs @@ -316,9 +316,10 @@ impl Accessor for UpyunBackend { } async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> { + let concurrent = args.concurrent(); let writer = UpyunWriter::new(self.core.clone(), args, path.to_string()); - let w = oio::MultipartUploadWriter::new(writer); + let w = oio::MultipartUploadWriter::new(writer, concurrent); Ok((RpWrite::default(), w)) } diff --git a/core/src/types/operator/operator_futures.rs b/core/src/types/operator/operator_futures.rs index 8c75f56b3cf3..4faf5ea52301 100644 --- a/core/src/types/operator/operator_futures.rs +++ b/core/src/types/operator/operator_futures.rs @@ -478,6 +478,12 @@ impl 
FutureWrite {
             .map_args(|(args, bs)| (args.with_cache_control(v), bs));
         self
     }
+
+    /// Set the maximum concurrent write task amount.
+    pub fn concurrent(mut self, v: usize) -> Self {
+        self.0 = self.0.map_args(|(args, bs)| (args.with_concurrent(v), bs));
+        self
+    }
 }
 
 impl Future for FutureWrite {
@@ -543,6 +549,12 @@ impl FutureWriter {
         self.0 = self.0.map_args(|args| args.with_cache_control(v));
         self
     }
+
+    /// Set the maximum concurrent write task amount.
+    pub fn concurrent(mut self, v: usize) -> Self {
+        self.0 = self.0.map_args(|args| args.with_concurrent(v));
+        self
+    }
 }
 
 impl Future for FutureWriter {
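
For reference, a minimal usage sketch of the `concurrent` option this patch exposes
through `FutureWriter::concurrent`. It assumes a tokio runtime and the 0.44-era
`Operator`/`writer_with` builder API; the S3 bucket, region, and object path are
placeholders, and any service backed by `oio::MultipartUploadWriter` (s3, oss, obs,
cos, b2, upyun) would be driven the same way.

    use opendal::services::S3;
    use opendal::Operator;
    use opendal::Result;

    #[tokio::main]
    async fn main() -> Result<()> {
        // Placeholder configuration; real credentials and bucket are required.
        let mut builder = S3::default();
        builder.bucket("example-bucket");
        builder.region("us-east-1");
        let op = Operator::new(builder)?.finish();

        // `concurrent(8)` ends up in `OpWrite::with_concurrent(8)`, so the
        // MultipartUploadWriter keeps up to 8 `write_part` calls in flight.
        let mut w = op.writer_with("path/to/large-object").concurrent(8).await?;
        for _ in 0..16 {
            // Each buffered write becomes one part of the multipart upload
            // (8 MiB here, above S3's 5 MiB minimum part size).
            w.write(vec![0u8; 8 * 1024 * 1024]).await?;
        }
        w.close().await?;
        Ok(())
    }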
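
The concurrency itself is a bounded, ordered queue: `ConcurrentFutures` falls back to
`FuturesOrdered` for larger limits (see the `clear` change above), and
`poll_write`/`poll_close` push `write_part` futures while capacity remains and drain
finished parts otherwise. The standalone sketch below reproduces that pattern with a
dummy `upload_part`; it uses plain `futures` types and a tokio runtime, not OpenDAL
internals.

    use futures::stream::{FuturesOrdered, StreamExt};

    // Stand-in for MultipartUploadWrite::write_part, returning a part etag.
    async fn upload_part(part_number: usize) -> std::io::Result<String> {
        Ok(format!("etag-{part_number}"))
    }

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let concurrent = 4;
        let mut in_flight = FuturesOrdered::new();
        let mut etags = Vec::new();

        for part_number in 0..16 {
            // Mirrors `futures.has_remaining()`: once the queue is full,
            // wait for the oldest part before submitting a new one.
            if in_flight.len() >= concurrent {
                if let Some(etag) = in_flight.next().await {
                    etags.push(etag?);
                }
            }
            in_flight.push_back(upload_part(part_number));
        }

        // Mirrors the drain loop in poll_close before complete_part is called.
        while let Some(etag) = in_flight.next().await {
            etags.push(etag?);
        }
        println!("completed {} parts", etags.len());
        Ok(())
    }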