Skip to content

Commit

Permalink
feat: implement concurrent MultipartUploadWriter (#3915)
Browse files Browse the repository at this point in the history
* feat: expose the concurrent field

* feat: implement concurrent `MultipartUploadWriter`

* chore: apply suggestions from CR

* chore: apply suggestions from CR

* chore: correct comments

* fix: clear future queue while aborting
  • Loading branch information
WenyXu authored Jan 5, 2024
1 parent a9858cd commit 23db343
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 103 deletions.
9 changes: 9 additions & 0 deletions core/src/raw/futures_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ where
}
}

/// Discard every queued task, leaving the collection empty.
///
/// Drops both ongoing and ready-but-unconsumed futures.
pub fn clear(&mut self) {
    match &mut self.tasks {
        // Large task sets are backed by a FuturesOrdered; swap in a fresh one.
        Tasks::Large(queue) => *queue = FuturesOrdered::new(),
        // Small task sets are a plain collection; clear it in place.
        Tasks::Small(queue) => queue.clear(),
        // The single-future slot is simply emptied.
        Tasks::Once(slot) => *slot = None,
    }
}

/// Return the length of current concurrent futures (both ongoing and ready).
pub fn len(&self) -> usize {
match &self.tasks {
Expand Down
213 changes: 116 additions & 97 deletions core/src/raw/oio/write/multipart_upload_write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::pin::Pin;
use std::sync::Arc;
use std::task::ready;
use std::task::Context;
use std::task::Poll;

use async_trait::async_trait;
use futures::Future;
use futures::FutureExt;
use futures::StreamExt;

use crate::raw::*;
use crate::*;
Expand Down Expand Up @@ -103,44 +107,76 @@ pub struct MultipartUploadPart {
pub etag: String,
}

/// Newtype wrapper over the boxed future that uploads a single part, so it
/// can be stored in `ConcurrentFutures` and given manual `Send`/`Sync` impls.
struct WritePartFuture(BoxedFuture<Result<MultipartUploadPart>>);

/// # Safety
///
/// wasm32 is a special target that we only have one event-loop for this WritePartFuture.
unsafe impl Send for WritePartFuture {}

/// # Safety
///
/// We will only take `&mut Self` reference for WritePartFuture.
unsafe impl Sync for WritePartFuture {}

impl Future for WritePartFuture {
    type Output = Result<MultipartUploadPart>;

    /// Forward polling straight to the inner boxed future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The inner future is already pinned on the heap, so re-borrowing it
        // pinned and polling is equivalent to `poll_unpin`.
        self.get_mut().0.as_mut().poll(cx)
    }
}

/// MultipartUploadWriter will implements [`Write`] based on multipart
/// uploads.
pub struct MultipartUploadWriter<W: MultipartUploadWrite> {
state: State<W>,
state: State,
w: Arc<W>,

cache: Option<oio::ChunkedBytes>,
upload_id: Option<Arc<String>>,
parts: Vec<MultipartUploadPart>,
cache: Option<oio::ChunkedBytes>,
futures: ConcurrentFutures<WritePartFuture>,
next_part_number: usize,
}

enum State<W> {
Idle(Option<W>),
Init(BoxedFuture<(W, Result<String>)>),
Write(BoxedFuture<(W, Result<MultipartUploadPart>)>),
Close(BoxedFuture<(W, Result<()>)>),
Abort(BoxedFuture<(W, Result<()>)>),
enum State {
Idle,
Init(BoxedFuture<Result<String>>),
Close(BoxedFuture<Result<()>>),
Abort(BoxedFuture<Result<()>>),
}

/// # Safety
///
/// wasm32 is a special target that we only have one event-loop for this state.
unsafe impl<S: MultipartUploadWrite> Send for State<S> {}
unsafe impl Send for State {}
/// # Safety
///
/// We will only take `&mut Self` reference for State.
unsafe impl<S: MultipartUploadWrite> Sync for State<S> {}
unsafe impl Sync for State {}

impl<W: MultipartUploadWrite> MultipartUploadWriter<W> {
/// Create a new MultipartUploadWriter.
pub fn new(inner: W) -> Self {
pub fn new(inner: W, concurrent: usize) -> Self {
Self {
state: State::Idle(Some(inner)),
state: State::Idle,

cache: None,
w: Arc::new(inner),
upload_id: None,
parts: Vec::new(),
cache: None,
futures: ConcurrentFutures::new(1.max(concurrent)),
next_part_number: 0,
}
}

fn fill_cache(&mut self, bs: &dyn oio::WriteBuf) -> usize {
let size = bs.remaining();
let bs = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size));
assert!(self.cache.is_none());
self.cache = Some(bs);
size
}
}

impl<W> oio::Write for MultipartUploadWriter<W>
Expand All @@ -150,61 +186,49 @@ where
fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll<Result<usize>> {
loop {
match &mut self.state {
State::Idle(w) => {
State::Idle => {
match self.upload_id.as_ref() {
Some(upload_id) => {
let upload_id = upload_id.clone();
let part_number = self.parts.len();

let bs = self.cache.clone().expect("cache must be valid").clone();
let w = w.take().expect("writer must be valid");
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
if self.futures.has_remaining() {
let cache = self.cache.take().expect("pending write must exist");
let part_number = self.next_part_number;
self.next_part_number += 1;
let w = self.w.clone();
let size = cache.len();
self.futures.push(WritePartFuture(Box::pin(async move {
w.write_part(
&upload_id,
part_number,
size as u64,
AsyncBody::ChunkedBytes(bs),
AsyncBody::ChunkedBytes(cache),
)
.await;

(w, part)
}));
.await
})));
let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
} else if let Some(part) = ready!(self.futures.poll_next_unpin(cx)) {
self.parts.push(part?);
}
}
None => {
// Fill cache with the first write.
if self.cache.is_none() {
let size = bs.remaining();
let cb = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size));
self.cache = Some(cb);
let size = self.fill_cache(bs);
return Poll::Ready(Ok(size));
}

let w = w.take().expect("writer must be valid");
self.state = State::Init(Box::pin(async move {
let upload_id = w.initiate_part().await;
(w, upload_id)
}));
let w = self.w.clone();
self.state =
State::Init(Box::pin(async move { w.initiate_part().await }));
}
}
}
State::Init(fut) => {
let (w, upload_id) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
let upload_id = ready!(fut.as_mut().poll(cx));
self.state = State::Idle;
self.upload_id = Some(Arc::new(upload_id?));
}
State::Write(fut) => {
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.parts.push(part?);

// Replace the cache when last write succeeded
let size = bs.remaining();
let cb = oio::ChunkedBytes::from_vec(bs.vectored_bytes(size));
self.cache = Some(cb);
return Poll::Ready(Ok(size));
}
State::Close(_) => {
unreachable!(
"MultipartUploadWriter must not go into State::Close during poll_write"
Expand All @@ -222,73 +246,71 @@ where
fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match &mut self.state {
State::Idle(w) => {
let w = w.take().expect("writer must be valid");
State::Idle => {
match self.upload_id.clone() {
Some(upload_id) => {
let parts = self.parts.clone();
match self.cache.clone() {
Some(bs) => {
let upload_id = upload_id.clone();
self.state = State::Write(Box::pin(async move {
let size = bs.len();
let part = w
.write_part(
let w = self.w.clone();
if self.futures.is_empty() && self.cache.is_none() {
let upload_id = upload_id.clone();
let parts = self.parts.clone();
self.state = State::Close(Box::pin(async move {
w.complete_part(&upload_id, &parts).await
}));
} else {
if self.futures.has_remaining() {
if let Some(cache) = self.cache.take() {
let upload_id = upload_id.clone();
let part_number = self.next_part_number;
self.next_part_number += 1;
let size = cache.len();
let w = self.w.clone();
self.futures.push(WritePartFuture(Box::pin(async move {
w.write_part(
&upload_id,
parts.len(),
part_number,
size as u64,
AsyncBody::ChunkedBytes(bs),
AsyncBody::ChunkedBytes(cache),
)
.await;
(w, part)
}));
.await
})));
}
}
None => {
self.state = State::Close(Box::pin(async move {
let res = w.complete_part(&upload_id, &parts).await;
(w, res)
}));
while let Some(part) = ready!(self.futures.poll_next_unpin(cx)) {
self.parts.push(part?);
}
}
}
None => match self.cache.clone() {
Some(bs) => {
None => match &self.cache {
Some(cache) => {
let w = self.w.clone();
let bs = cache.clone();
self.state = State::Close(Box::pin(async move {
let size = bs.len();
let res = w
.write_once(size as u64, AsyncBody::ChunkedBytes(bs))
.await;
(w, res)
w.write_once(size as u64, AsyncBody::ChunkedBytes(bs)).await
}));
}
None => {
let w = self.w.clone();
// Call write_once if there is no data in cache and no upload_id.
self.state = State::Close(Box::pin(async move {
let res = w.write_once(0, AsyncBody::Empty).await;
(w, res)
w.write_once(0, AsyncBody::Empty).await
}));
}
},
}
}
State::Close(fut) => {
let (w, res) = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
let res = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle;
// We should check res first before clean up cache.
res?;

self.cache = None;

return Poll::Ready(Ok(()));
}
State::Init(_) => unreachable!(
"MultipartUploadWriter must not go into State::Init during poll_close"
),
State::Write(fut) => {
let (w, part) = ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
self.parts.push(part?);
self.cache = None;
}
State::Abort(_) => unreachable!(
"MultipartUploadWriter must not go into State::Abort during poll_close"
),
Expand All @@ -299,32 +321,29 @@ where
fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop {
match &mut self.state {
State::Idle(w) => {
let w = w.take().expect("writer must be valid");
State::Idle => {
let w = self.w.clone();
match self.upload_id.clone() {
Some(upload_id) => {
self.state = State::Abort(Box::pin(async move {
let res = w.abort_part(&upload_id).await;
(w, res)
}));
self.futures.clear();
self.state =
State::Abort(Box::pin(
async move { w.abort_part(&upload_id).await },
));
}
None => {
self.cache = None;
return Poll::Ready(Ok(()));
}
}
}
State::Abort(fut) => {
let (w, res) = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle(Some(w));
let res = futures::ready!(fut.as_mut().poll(cx));
self.state = State::Idle;
return Poll::Ready(res);
}
State::Init(_) => unreachable!(
"MultipartUploadWriter must not go into State::Init during poll_abort"
),
State::Write(_) => unreachable!(
"MultipartUploadWriter must not go into State::Write during poll_abort"
),
State::Close(_) => unreachable!(
"MultipartUploadWriter must not go into State::Close during poll_abort"
),
Expand Down
12 changes: 12 additions & 0 deletions core/src/raw/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ impl OpStat {
pub struct OpWrite {
append: bool,
buffer: Option<usize>,
concurrent: usize,

content_type: Option<String>,
content_disposition: Option<String>,
Expand Down Expand Up @@ -601,6 +602,17 @@ impl OpWrite {
self.cache_control = Some(cache_control.to_string());
self
}

/// Get the maximum number of concurrent write tasks configured for this
/// operation.
pub fn concurrent(&self) -> usize {
    self.concurrent
}

/// Set the maximum concurrent write task amount.
pub fn with_concurrent(mut self, concurrent: usize) -> Self {
self.concurrent = concurrent;
self
}
}

/// Args for `copy` operation.
Expand Down
3 changes: 2 additions & 1 deletion core/src/services/b2/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ impl Accessor for B2Backend {
}

async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
    // Capture the requested concurrency before `args` is moved into the writer.
    let concurrent = args.concurrent();
    let writer = B2Writer::new(self.core.clone(), path, args);

    // NOTE(review): the original span kept the stale one-argument
    // `MultipartUploadWriter::new(writer)` call next to the new two-argument
    // call (diff artifact); only the post-commit call is kept.
    let w = oio::MultipartUploadWriter::new(writer, concurrent);

    Ok((RpWrite::default(), w))
}
Expand Down
Loading

0 comments on commit 23db343

Please sign in to comment.