From 6ebd75c65f840d840205e8e2d6d2e04318aab248 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Thu, 14 Mar 2024 18:41:36 +0800 Subject: [PATCH] refactor(core/raw): Migrate `oio::Write` to async in trait (#4358) --- core/benches/oio/utils.rs | 15 +- core/benches/oio/write.rs | 2 +- core/src/layers/blocking.rs | 8 +- core/src/layers/complete.rs | 23 +- core/src/layers/concurrent_limit.rs | 14 +- core/src/layers/dtrace.rs | 24 +- core/src/layers/error_context.rs | 15 +- core/src/layers/logging.rs | 27 +- core/src/layers/madsim.rs | 12 +- core/src/layers/metrics.rs | 16 +- core/src/layers/minitrace.rs | 15 +- core/src/layers/oteltrace.rs | 15 +- core/src/layers/prometheus.rs | 18 +- core/src/layers/prometheus_client.rs | 16 +- core/src/layers/retry.rs | 217 +++++------- core/src/layers/throttle.rs | 19 +- core/src/layers/timeout.rs | 59 +--- core/src/layers/tracing.rs | 15 +- core/src/raw/adapters/kv/backend.rs | 66 +--- core/src/raw/adapters/typed_kv/backend.rs | 67 +--- core/src/raw/enum_utils.rs | 44 ++- core/src/raw/oio/write/api.rs | 121 +++---- core/src/raw/oio/write/append_write.rs | 101 ++---- core/src/raw/oio/write/block_write.rs | 241 ++++++-------- core/src/raw/oio/write/exact_buf_write.rs | 32 +- core/src/raw/oio/write/mod.rs | 1 - core/src/raw/oio/write/multipart_write.rs | 333 ++++++++----------- core/src/raw/oio/write/one_shot_write.rs | 95 ++---- core/src/raw/oio/write/range_write.rs | 274 ++++++--------- core/src/services/alluxio/writer.rs | 115 ++----- core/src/services/azblob/writer.rs | 6 - core/src/services/azdls/writer.rs | 3 - core/src/services/azfile/writer.rs | 7 +- core/src/services/b2/writer.rs | 2 - core/src/services/chainsafe/writer.rs | 2 - core/src/services/cos/writer.rs | 3 - core/src/services/dbfs/writer.rs | 2 - core/src/services/dropbox/writer.rs | 2 - core/src/services/fs/writer.rs | 79 ++--- core/src/services/ftp/writer.rs | 2 - core/src/services/gcs/writer.rs | 3 - core/src/services/gdrive/writer.rs | 3 - core/src/services/ghac/writer.rs | 137 ++------ core/src/services/github/writer.rs | 2 - core/src/services/hdfs/writer.rs | 69 +--- core/src/services/hdfs_native/writer.rs | 12 +- core/src/services/ipmfs/writer.rs | 2 - core/src/services/koofr/writer.rs | 2 - core/src/services/obs/writer.rs | 3 - core/src/services/onedrive/writer.rs | 2 - core/src/services/oss/writer.rs | 3 - core/src/services/pcloud/writer.rs | 2 - core/src/services/s3/writer.rs | 3 - core/src/services/seafile/writer.rs | 2 - core/src/services/sftp/writer.rs | 33 +- core/src/services/supabase/writer.rs | 2 - core/src/services/swift/writer.rs | 2 - core/src/services/upyun/writer.rs | 2 - core/src/services/vercel_artifacts/writer.rs | 2 - core/src/services/vercel_blob/writer.rs | 2 - core/src/services/webdav/writer.rs | 2 - core/src/services/webhdfs/writer.rs | 3 - core/src/services/yandex_disk/writer.rs | 2 - core/src/types/operator/operator.rs | 13 +- core/src/types/writer.rs | 149 +++++++-- 65 files changed, 945 insertions(+), 1640 deletions(-) diff --git a/core/benches/oio/utils.rs b/core/benches/oio/utils.rs index 130243bbd7a8..3b580eb6e8ba 100644 --- a/core/benches/oio/utils.rs +++ b/core/benches/oio/utils.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::task::Context; -use std::task::Poll; - use bytes::Bytes; use opendal::raw::oio; use rand::prelude::ThreadRng; @@ -27,16 +24,16 @@ use rand::RngCore; pub struct BlackHoleWriter; impl oio::Write for BlackHoleWriter { - fn poll_write(&mut self, _: &mut Context<'_>, bs: Bytes) -> Poll> { - Poll::Ready(Ok(bs.len())) + async fn write(&mut self, bs: Bytes) -> opendal::Result { + Ok(bs.len()) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn abort(&mut self) -> opendal::Result<()> { + Ok(()) } - fn poll_close(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn close(&mut self) -> opendal::Result<()> { + Ok(()) } } diff --git a/core/benches/oio/write.rs b/core/benches/oio/write.rs index c8298767170b..84fded59de82 100644 --- a/core/benches/oio/write.rs +++ b/core/benches/oio/write.rs @@ -19,7 +19,7 @@ use bytes::Buf; use criterion::Criterion; use once_cell::sync::Lazy; use opendal::raw::oio::ExactBufWriter; -use opendal::raw::oio::WriteExt; +use opendal::raw::oio::Write; use rand::thread_rng; use size::Size; diff --git a/core/src/layers/blocking.rs b/core/src/layers/blocking.rs index 4e8cf250fe16..e44eda903bcd 100644 --- a/core/src/layers/blocking.rs +++ b/core/src/layers/blocking.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use bytes; use bytes::Bytes; -use futures::future::poll_fn; + use tokio::runtime::Handle; use crate::raw::*; @@ -299,13 +299,11 @@ impl oio::BlockingRead for BlockingWrapper { impl oio::BlockingWrite for BlockingWrapper { fn write(&mut self, bs: Bytes) -> Result { - self.handle - .block_on(poll_fn(|cx| self.inner.poll_write(cx, bs.clone()))) + self.handle.block_on(self.inner.write(bs)) } fn close(&mut self) -> Result<()> { - self.handle - .block_on(poll_fn(|cx| self.inner.poll_close(cx))) + self.handle.block_on(self.inner.close()) } } diff --git a/core/src/layers/complete.rs b/core/src/layers/complete.rs index 8120a74cbe0f..d4985a0f5235 100644 --- a/core/src/layers/complete.rs +++ b/core/src/layers/complete.rs @@ -18,10 +18,8 @@ use std::cmp; use std::fmt::Debug; use std::fmt::Formatter; + use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -158,7 +156,7 @@ impl CompleteAccessor { } if capability.write_can_empty && capability.list { let (_, mut w) = self.inner.write(path, OpWrite::default()).await?; - oio::WriteExt::close(&mut w).await?; + oio::Write::close(&mut w).await?; return Ok(RpCreateDir::default()); } @@ -712,35 +710,34 @@ impl oio::Write for CompleteWriter where W: oio::Write, { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let w = self.inner.as_mut().ok_or_else(|| { Error::new(ErrorKind::Unexpected, "writer has been closed or aborted") })?; - let n = ready!(w.poll_write(cx, bs))?; - Poll::Ready(Ok(n)) + w.write(bs).await } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { let w = self.inner.as_mut().ok_or_else(|| { Error::new(ErrorKind::Unexpected, "writer has been closed or aborted") })?; - ready!(w.poll_close(cx))?; + w.close().await?; self.inner = None; - Poll::Ready(Ok(())) + Ok(()) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { let w = self.inner.as_mut().ok_or_else(|| { Error::new(ErrorKind::Unexpected, "writer has been closed or aborted") })?; - ready!(w.poll_abort(cx))?; + w.abort().await?; self.inner = None; - Poll::Ready(Ok(())) + Ok(()) } } diff --git a/core/src/layers/concurrent_limit.rs b/core/src/layers/concurrent_limit.rs index 8dc043ff08ed..5fd6dd87bad2 100644 --- a/core/src/layers/concurrent_limit.rs +++ b/core/src/layers/concurrent_limit.rs @@ -19,8 +19,6 @@ use std::fmt::Debug; use std::io::SeekFrom; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -278,16 +276,16 @@ impl oio::BlockingRead for ConcurrentLimitWrapper { } impl oio::Write for ConcurrentLimitWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.inner.poll_write(cx, bs) + async fn write(&mut self, bs: Bytes) -> Result { + self.inner.write(bs).await } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx) + async fn close(&mut self) -> Result<()> { + self.inner.close().await } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx) + async fn abort(&mut self) -> Result<()> { + self.inner.abort().await } } diff --git a/core/src/layers/dtrace.rs b/core/src/layers/dtrace.rs index ee14d8b6e292..11b62dc1c43c 100644 --- a/core/src/layers/dtrace.rs +++ b/core/src/layers/dtrace.rs @@ -18,9 +18,8 @@ use std::ffi::CString; use std::fmt::Debug; use std::fmt::Formatter; + use std::io; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -408,12 +407,13 @@ impl oio::BlockingRead for DtraceLayerWrapper { } impl oio::Write for DtraceLayerWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let c_path = CString::new(self.path.clone()).unwrap(); probe_lazy!(opendal, writer_write_start, c_path.as_ptr()); self.inner - .poll_write(cx, bs) - .map_ok(|n| { + .write(bs) + .await + .map(|n| { probe_lazy!(opendal, writer_write_ok, c_path.as_ptr(), n); n }) @@ -423,12 +423,13 @@ impl oio::Write for DtraceLayerWrapper { }) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { let c_path = CString::new(self.path.clone()).unwrap(); probe_lazy!(opendal, writer_poll_abort_start, c_path.as_ptr()); self.inner - .poll_abort(cx) - .map_ok(|_| { + .abort() + .await + .map(|_| { probe_lazy!(opendal, writer_poll_abort_ok, c_path.as_ptr()); }) .map_err(|err| { @@ -437,12 +438,13 @@ impl oio::Write for DtraceLayerWrapper { }) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { let c_path = CString::new(self.path.clone()).unwrap(); probe_lazy!(opendal, writer_close_start, c_path.as_ptr()); self.inner - .poll_close(cx) - .map_ok(|_| { + .close() + .await + .map(|_| { probe_lazy!(opendal, writer_close_ok, c_path.as_ptr()); }) .map_err(|err| { diff --git a/core/src/layers/error_context.rs b/core/src/layers/error_context.rs index 8b6ef863ee89..de37bc690b28 100644 --- a/core/src/layers/error_context.rs +++ b/core/src/layers/error_context.rs @@ -19,8 +19,6 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::io::SeekFrom; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -387,10 +385,9 @@ impl oio::BlockingRead for ErrorContextWrapper { } } -#[async_trait::async_trait] impl oio::Write for ErrorContextWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.inner.poll_write(cx, bs.clone()).map_err(|err| { + async fn write(&mut self, bs: Bytes) -> Result { + self.inner.write(bs.clone()).await.map_err(|err| { err.with_operation(WriteOperation::Write) .with_context("service", self.scheme) .with_context("path", &self.path) @@ -398,16 +395,16 @@ impl oio::Write for ErrorContextWrapper { }) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx).map_err(|err| { + async fn close(&mut self) -> Result<()> { + self.inner.close().await.map_err(|err| { err.with_operation(WriteOperation::Close) .with_context("service", self.scheme) .with_context("path", &self.path) }) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx).map_err(|err| { + async fn abort(&mut self) -> Result<()> { + self.inner.abort().await.map_err(|err| { err.with_operation(WriteOperation::Abort) .with_context("service", self.scheme) .with_context("path", &self.path) diff --git a/core/src/layers/logging.rs b/core/src/layers/logging.rs index fb274391c0ec..516dd9ab7e7f 100644 --- a/core/src/layers/logging.rs +++ b/core/src/layers/logging.rs @@ -18,9 +18,6 @@ use std::fmt::Debug; use std::io; -use std::task::ready; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -1147,8 +1144,8 @@ impl LoggingWriter { } impl oio::Write for LoggingWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - match ready!(self.inner.poll_write(cx, bs.clone())) { + async fn write(&mut self, bs: Bytes) -> Result { + match self.inner.write(bs.clone()).await { Ok(n) => { self.written += n as u64; trace!( @@ -1161,7 +1158,7 @@ impl oio::Write for LoggingWriter { bs.len(), n, ); - Poll::Ready(Ok(n)) + Ok(n) } Err(err) => { if let Some(lvl) = self.ctx.error_level(&err) { @@ -1176,13 +1173,13 @@ impl oio::Write for LoggingWriter { self.ctx.error_print(&err), ) } - Poll::Ready(Err(err)) + Err(err) } } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_abort(cx)) { + async fn abort(&mut self) -> Result<()> { + match self.inner.abort().await { Ok(_) => { trace!( target: LOGGING_TARGET, @@ -1192,7 +1189,7 @@ impl oio::Write for LoggingWriter { self.path, self.written, ); - Poll::Ready(Ok(())) + Ok(()) } Err(err) => { if let Some(lvl) = self.ctx.error_level(&err) { @@ -1207,13 +1204,13 @@ impl oio::Write for LoggingWriter { self.ctx.error_print(&err), ) } - Poll::Ready(Err(err)) + Err(err) } } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - match ready!(self.inner.poll_close(cx)) { + async fn close(&mut self) -> Result<()> { + match self.inner.close().await { Ok(_) => { debug!( target: LOGGING_TARGET, @@ -1223,7 +1220,7 @@ impl oio::Write for LoggingWriter { self.path, self.written ); - Poll::Ready(Ok(())) + Ok(()) } Err(err) => { if let Some(lvl) = self.ctx.error_level(&err) { @@ -1238,7 +1235,7 @@ impl oio::Write for LoggingWriter { self.ctx.error_print(&err), ) } - Poll::Ready(Err(err)) + Err(err) } } } diff --git a/core/src/layers/madsim.rs b/core/src/layers/madsim.rs index f423706dc35f..7573fd27e620 100644 --- a/core/src/layers/madsim.rs +++ b/core/src/layers/madsim.rs @@ -291,7 +291,7 @@ pub struct MadsimWriter { } impl oio::Write for MadsimWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> crate::Result { #[cfg(madsim)] { let req = Request::Write(self.path.to_string(), bs); @@ -307,15 +307,15 @@ impl oio::Write for MadsimWriter { } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> crate::Result<()> { + Err(Error::new( ErrorKind::Unsupported, "will be supported in the future", - ))) + )) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn close(&mut self) -> crate::Result<()> { + Ok(()) } } diff --git a/core/src/layers/metrics.rs b/core/src/layers/metrics.rs index a73127f9c795..dc531cb10d3b 100644 --- a/core/src/layers/metrics.rs +++ b/core/src/layers/metrics.rs @@ -17,10 +17,10 @@ use std::fmt::Debug; use std::fmt::Formatter; +use std::future::Future; use std::io; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; + use std::time::Instant; use async_trait::async_trait; @@ -820,9 +820,9 @@ impl oio::BlockingRead for MetricWrapper { } impl oio::Write for MetricWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + fn write(&mut self, bs: Bytes) -> impl Future> + Send { self.inner - .poll_write(cx, bs) + .write(bs) .map_ok(|n| { self.bytes += n as u64; n @@ -833,15 +833,15 @@ impl oio::Write for MetricWrapper { }) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx).map_err(|err| { + fn abort(&mut self) -> impl Future> + Send { + self.inner.abort().map_err(|err| { self.handle.increment_errors_total(self.op, err.kind()); err }) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx).map_err(|err| { + fn close(&mut self) -> impl Future> + Send { + self.inner.close().map_err(|err| { self.handle.increment_errors_total(self.op, err.kind()); err }) diff --git a/core/src/layers/minitrace.rs b/core/src/layers/minitrace.rs index cdb54312f3ee..60ee95f32758 100644 --- a/core/src/layers/minitrace.rs +++ b/core/src/layers/minitrace.rs @@ -16,10 +16,9 @@ // under the License. use std::fmt::Debug; +use std::future::Future; use std::io; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -324,22 +323,22 @@ impl oio::BlockingRead for MinitraceWrapper { } impl oio::Write for MinitraceWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + fn write(&mut self, bs: Bytes) -> impl Future> + Send { let _g = self.span.set_local_parent(); let _span = LocalSpan::enter_with_local_parent(WriteOperation::Write.into_static()); - self.inner.poll_write(cx, bs) + self.inner.write(bs) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + fn abort(&mut self) -> impl Future> + Send { let _g = self.span.set_local_parent(); let _span = LocalSpan::enter_with_local_parent(WriteOperation::Abort.into_static()); - self.inner.poll_abort(cx) + self.inner.abort() } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + fn close(&mut self) -> impl Future> + Send { let _g = self.span.set_local_parent(); let _span = LocalSpan::enter_with_local_parent(WriteOperation::Close.into_static()); - self.inner.poll_close(cx) + self.inner.close() } } diff --git a/core/src/layers/oteltrace.rs b/core/src/layers/oteltrace.rs index 1453cc2540ba..727a0bc6452c 100644 --- a/core/src/layers/oteltrace.rs +++ b/core/src/layers/oteltrace.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::future::Future; use std::io; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -298,16 +297,16 @@ impl oio::BlockingRead for OtelTraceWrapper { } impl oio::Write for OtelTraceWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.inner.poll_write(cx, bs) + fn write(&mut self, bs: Bytes) -> impl Future> + Send { + self.inner.write(bs) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx) + fn abort(&mut self) -> impl Future> + Send { + self.inner.abort() } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx) + fn close(&mut self) -> impl Future> + Send { + self.inner.close() } } diff --git a/core/src/layers/prometheus.rs b/core/src/layers/prometheus.rs index 64a4f61bace5..6b85844a2d0d 100644 --- a/core/src/layers/prometheus.rs +++ b/core/src/layers/prometheus.rs @@ -17,10 +17,9 @@ use std::fmt::Debug; use std::fmt::Formatter; + use std::io; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -749,15 +748,16 @@ impl oio::BlockingRead for PrometheusMetricWrapper { } impl oio::Write for PrometheusMetricWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let labels = self.stats.generate_metric_label( self.scheme.into_static(), Operation::Write.into_static(), &self.path, ); self.inner - .poll_write(cx, bs) - .map_ok(|n| { + .write(bs) + .await + .map(|n| { self.stats .bytes_total .with_label_values(&labels) @@ -770,15 +770,15 @@ impl oio::Write for PrometheusMetricWrapper { }) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx).map_err(|err| { + async fn abort(&mut self) -> Result<()> { + self.inner.abort().await.map_err(|err| { self.stats.increment_errors_total(self.op, err.kind()); err }) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx).map_err(|err| { + async fn close(&mut self) -> Result<()> { + self.inner.close().await.map_err(|err| { self.stats.increment_errors_total(self.op, err.kind()); err }) diff --git a/core/src/layers/prometheus_client.rs b/core/src/layers/prometheus_client.rs index 8558e4cab610..1276f5c69549 100644 --- a/core/src/layers/prometheus_client.rs +++ b/core/src/layers/prometheus_client.rs @@ -17,10 +17,10 @@ use std::fmt::Debug; use std::fmt::Formatter; +use std::future::Future; use std::io; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; + use std::time::Duration; use std::time::Instant; @@ -590,9 +590,9 @@ impl oio::BlockingRead for PrometheusMetricWrapper { } impl oio::Write for PrometheusMetricWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + fn write(&mut self, bs: Bytes) -> impl Future> + Send { self.inner - .poll_write(cx, bs) + .write(bs) .map_ok(|n| { self.bytes_total += n; n @@ -604,16 +604,16 @@ impl oio::Write for PrometheusMetricWrapper { }) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx).map_err(|err| { + fn abort(&mut self) -> impl Future> + Send { + self.inner.abort().map_err(|err| { self.metrics .increment_errors_total(self.scheme, self.op, err.kind()); err }) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx).map_err(|err| { + fn close(&mut self) -> impl Future> + Send { + self.inner.close().map_err(|err| { self.metrics .increment_errors_total(self.scheme, self.op, err.kind()); err diff --git a/core/src/layers/retry.rs b/core/src/layers/retry.rs index de37c722c889..c73b11696a87 100644 --- a/core/src/layers/retry.rs +++ b/core/src/layers/retry.rs @@ -19,17 +19,14 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::io; -use std::pin::Pin; + use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; + use std::time::Duration; use async_trait::async_trait; -use backon::BackoffBuilder; use backon::BlockingRetryable; -use backon::ExponentialBackoff; + use backon::ExponentialBuilder; use backon::Retryable; use bytes::Bytes; @@ -653,8 +650,6 @@ pub struct RetryWrapper { path: String, builder: ExponentialBuilder, - current_backoff: Option, - sleep: Option>>, } impl RetryWrapper { @@ -665,8 +660,6 @@ impl RetryWrapper { path: path.to_string(), builder: backoff, - current_backoff: None, - sleep: None, } } } @@ -776,142 +769,100 @@ impl oio::BlockingRead for RetryWrapp } impl oio::Write for RetryWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - if let Some(sleep) = self.sleep.as_mut() { - ready!(sleep.poll_unpin(cx)); - self.sleep = None; - } + async fn write(&mut self, bs: Bytes) -> Result { + use backon::RetryableWithContext; - match ready!(self.inner.as_mut().unwrap().poll_write(cx, bs.clone())) { - Ok(v) => { - self.current_backoff = None; - Poll::Ready(Ok(v)) - } - Err(err) if !err.is_temporary() => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Err(err) => { - let backoff = match self.current_backoff.as_mut() { - Some(backoff) => backoff, - None => { - self.current_backoff = Some(self.builder.build()); - self.current_backoff.as_mut().unwrap() - } - }; - - match backoff.next() { - None => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Some(dur) => { - self.notify.intercept( - &err, - dur, - &[ - ("operation", WriteOperation::Write.into_static()), - ("path", &self.path), - ], - ); - self.sleep = Some(Box::pin(tokio::time::sleep(dur))); - self.poll_write(cx, bs.clone()) - } - } + let inner = self.inner.take().expect("inner must be valid"); + + let ((inner, _), res) = { + |(mut r, bs): (R, Bytes)| async move { + let res = r.write(bs.clone()).await; + + ((r, bs), res) } } + .retry(&self.builder) + .when(|e| e.is_temporary()) + .context((inner, bs)) + .notify(|err, dur| { + self.notify.intercept( + err, + dur, + &[ + ("operation", WriteOperation::Write.into_static()), + ("path", &self.path), + ], + ) + }) + .map(|(r, res)| (r, res.map_err(|err| err.set_persistent()))) + .await; + + self.inner = Some(inner); + res } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(sleep) = self.sleep.as_mut() { - ready!(sleep.poll_unpin(cx)); - self.sleep = None; - } + async fn abort(&mut self) -> Result<()> { + use backon::RetryableWithContext; - match ready!(self.inner.as_mut().unwrap().poll_abort(cx)) { - Ok(v) => { - self.current_backoff = None; - Poll::Ready(Ok(v)) - } - Err(err) if !err.is_temporary() => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Err(err) => { - let backoff = match self.current_backoff.as_mut() { - Some(backoff) => backoff, - None => { - self.current_backoff = Some(self.builder.build()); - self.current_backoff.as_mut().unwrap() - } - }; - - match backoff.next() { - None => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Some(dur) => { - self.notify.intercept( - &err, - dur, - &[ - ("operation", WriteOperation::Abort.into_static()), - ("path", &self.path), - ], - ); - self.sleep = Some(Box::pin(tokio::time::sleep(dur))); - self.poll_abort(cx) - } - } + let inner = self.inner.take().expect("inner must be valid"); + + let (inner, res) = { + |mut r: R| async move { + let res = r.abort().await; + + (r, res) } } + .retry(&self.builder) + .when(|e| e.is_temporary()) + .context(inner) + .notify(|err, dur| { + self.notify.intercept( + err, + dur, + &[ + ("operation", WriteOperation::Abort.into_static()), + ("path", &self.path), + ], + ) + }) + .map(|(r, res)| (r, res.map_err(|err| err.set_persistent()))) + .await; + + self.inner = Some(inner); + res } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - if let Some(sleep) = self.sleep.as_mut() { - ready!(sleep.poll_unpin(cx)); - self.sleep = None; - } + async fn close(&mut self) -> Result<()> { + use backon::RetryableWithContext; - match ready!(self.inner.as_mut().unwrap().poll_close(cx)) { - Ok(v) => { - self.current_backoff = None; - Poll::Ready(Ok(v)) - } - Err(err) if !err.is_temporary() => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Err(err) => { - let backoff = match self.current_backoff.as_mut() { - Some(backoff) => backoff, - None => { - self.current_backoff = Some(self.builder.build()); - self.current_backoff.as_mut().unwrap() - } - }; - - match backoff.next() { - None => { - self.current_backoff = None; - Poll::Ready(Err(err)) - } - Some(dur) => { - self.notify.intercept( - &err, - dur, - &[ - ("operation", WriteOperation::Close.into_static()), - ("path", &self.path), - ], - ); - self.sleep = Some(Box::pin(tokio::time::sleep(dur))); - self.poll_close(cx) - } - } + let inner = self.inner.take().expect("inner must be valid"); + + let (inner, res) = { + |mut r: R| async move { + let res = r.close().await; + + (r, res) } } + .retry(&self.builder) + .when(|e| e.is_temporary()) + .context(inner) + .notify(|err, dur| { + self.notify.intercept( + err, + dur, + &[ + ("operation", WriteOperation::Close.into_static()), + ("path", &self.path), + ], + ) + }) + .map(|(r, res)| (r, res.map_err(|err| err.set_persistent()))) + .await; + + self.inner = Some(inner); + res } } diff --git a/core/src/layers/throttle.rs b/core/src/layers/throttle.rs index 514d536c8049..05a67f8ab9e1 100644 --- a/core/src/layers/throttle.rs +++ b/core/src/layers/throttle.rs @@ -18,8 +18,7 @@ use std::io::SeekFrom; use std::num::NonZeroU32; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; + use std::thread; use async_trait::async_trait; @@ -207,13 +206,13 @@ impl oio::BlockingRead for ThrottleWrapper { } impl oio::Write for ThrottleWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let buf_length = NonZeroU32::new(bs.len() as u32).unwrap(); loop { match self.limiter.check_n(buf_length) { Ok(res) => match res { - Ok(_) => return self.inner.poll_write(cx, bs), + Ok(_) => return self.inner.write(bs).await, // the query is valid but the Decider can not accommodate them. Err(not_until) => { let _ = not_until.wait_time_from(DefaultClock::default().now()); @@ -224,20 +223,20 @@ impl oio::Write for ThrottleWrapper { } }, // the query was invalid as the rate limit parameters can "never" accommodate the number of cells queried for. - Err(_) => return Poll::Ready(Err(Error::new( + Err(_) => return Err(Error::new( ErrorKind::RateLimited, "InsufficientCapacity due to burst size being smaller than the request size", - ))), + )), } } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx) + async fn abort(&mut self) -> Result<()> { + self.inner.abort().await } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx) + async fn close(&mut self) -> Result<()> { + self.inner.close().await } } diff --git a/core/src/layers/timeout.rs b/core/src/layers/timeout.rs index 3e7906ed6a8f..34ba69df0ac3 100644 --- a/core/src/layers/timeout.rs +++ b/core/src/layers/timeout.rs @@ -17,10 +17,7 @@ use std::future::Future; use std::io::SeekFrom; -use std::pin::Pin; -use std::task::ready; -use std::task::Context; -use std::task::Poll; + use std::time::Duration; use async_trait::async_trait; @@ -274,16 +271,11 @@ pub struct TimeoutWrapper { inner: R, timeout: Duration, - sleep: Option>>, } impl TimeoutWrapper { fn new(inner: R, timeout: Duration) -> Self { - Self { - inner, - timeout, - sleep: None, - } + Self { inner, timeout } } #[inline] @@ -299,26 +291,6 @@ impl TimeoutWrapper { .set_temporary() })? } - - #[inline] - fn poll_timeout(&mut self, cx: &mut Context<'_>, op: &'static str) -> Result<()> { - let sleep = self - .sleep - .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); - - match sleep.as_mut().poll(cx) { - Poll::Pending => Ok(()), - Poll::Ready(_) => { - self.sleep = None; - Err( - Error::new(ErrorKind::Unexpected, "io operation timeout reached") - .with_operation(op) - .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) - .set_temporary(), - ) - } - } - } } impl oio::Read for TimeoutWrapper { @@ -334,28 +306,19 @@ impl oio::Read for TimeoutWrapper { } impl oio::Write for TimeoutWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.poll_timeout(cx, WriteOperation::Write.into_static())?; - - let v = ready!(self.inner.poll_write(cx, bs)); - self.sleep = None; - Poll::Ready(v) + async fn write(&mut self, bs: Bytes) -> Result { + let fut = self.inner.write(bs); + Self::io_timeout(self.timeout, WriteOperation::Write.into_static(), fut).await } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_timeout(cx, WriteOperation::Close.into_static())?; - - let v = ready!(self.inner.poll_close(cx)); - self.sleep = None; - Poll::Ready(v) + async fn close(&mut self) -> Result<()> { + let fut = self.inner.close(); + Self::io_timeout(self.timeout, WriteOperation::Close.into_static(), fut).await } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_timeout(cx, WriteOperation::Abort.into_static())?; - - let v = ready!(self.inner.poll_abort(cx)); - self.sleep = None; - Poll::Ready(v) + async fn abort(&mut self) -> Result<()> { + let fut = self.inner.abort(); + Self::io_timeout(self.timeout, WriteOperation::Abort.into_static(), fut).await } } diff --git a/core/src/layers/tracing.rs b/core/src/layers/tracing.rs index 6d35ad52bac5..6d9a8b2447c0 100644 --- a/core/src/layers/tracing.rs +++ b/core/src/layers/tracing.rs @@ -16,10 +16,9 @@ // under the License. use std::fmt::Debug; +use std::future::Future; use std::io; -use std::task::Context; -use std::task::Poll; use async_trait::async_trait; use bytes::Bytes; @@ -309,24 +308,24 @@ impl oio::Write for TracingWrapper { parent = &self.span, level = "trace", skip_all)] - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.inner.poll_write(cx, bs) + fn write(&mut self, bs: Bytes) -> impl Future> + Send { + self.inner.write(bs) } #[tracing::instrument( parent = &self.span, level = "trace", skip_all)] - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_abort(cx) + fn abort(&mut self) -> impl Future> + Send { + self.inner.abort() } #[tracing::instrument( parent = &self.span, level = "trace", skip_all)] - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_close(cx) + fn close(&mut self) -> impl Future> + Send { + self.inner.close() } } diff --git a/core/src/raw/adapters/kv/backend.rs b/core/src/raw/adapters/kv/backend.rs index c6b6114e4384..08447eeb1e99 100644 --- a/core/src/raw/adapters/kv/backend.rs +++ b/core/src/raw/adapters/kv/backend.rs @@ -16,16 +16,12 @@ // under the License. use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; + use std::vec::IntoIter; use async_trait::async_trait; use bytes::Bytes; use bytes::BytesMut; -use futures::future::BoxFuture; -use futures::FutureExt; use super::Adapter; use crate::raw::oio::HierarchyLister; @@ -265,7 +261,6 @@ pub struct KvWriter { path: String, buffer: Buffer, - future: Option>>, } impl KvWriter { @@ -274,7 +269,6 @@ impl KvWriter { kv, path, buffer: Buffer::Active(BytesMut::new()), - future: None, } } } @@ -290,62 +284,32 @@ enum Buffer { unsafe impl Sync for KvWriter {} impl oio::Write for KvWriter { - fn poll_write(&mut self, _: &mut Context<'_>, bs: Bytes) -> Poll> { - if self.future.is_some() { - self.future = None; - return Poll::Ready(Err(Error::new( - ErrorKind::Unexpected, - "there is a future on going, it's maybe a bug to go into this case", - ))); - } - + async fn write(&mut self, bs: Bytes) -> Result { match &mut self.buffer { Buffer::Active(buf) => { buf.extend_from_slice(&bs); - Poll::Ready(Ok(bs.len())) + Ok(bs.len()) } Buffer::Frozen(_) => unreachable!("KvWriter should not be frozen during poll_write"), } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match self.future.as_mut() { - Some(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.future = None; - return Poll::Ready(res); - } - None => { - let kv = self.kv.clone(); - let path = self.path.clone(); - let buf = match &mut self.buffer { - Buffer::Active(buf) => { - let buf = buf.split().freeze(); - self.buffer = Buffer::Frozen(buf.clone()); - buf - } - Buffer::Frozen(buf) => buf.clone(), - }; - - let fut = async move { kv.set(&path, &buf).await }; - self.future = Some(Box::pin(fut)); - } + async fn close(&mut self) -> Result<()> { + let buf = match &mut self.buffer { + Buffer::Active(buf) => { + let buf = buf.split().freeze(); + self.buffer = Buffer::Frozen(buf.clone()); + buf } - } - } + Buffer::Frozen(buf) => buf.clone(), + }; - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - if self.future.is_some() { - self.future = None; - return Poll::Ready(Err(Error::new( - ErrorKind::Unexpected, - "there is a future on going, it's maybe a bug to go into this case", - ))); - } + self.kv.set(&self.path, &buf).await + } + async fn abort(&mut self) -> Result<()> { self.buffer = Buffer::Active(BytesMut::new()); - Poll::Ready(Ok(())) + Ok(()) } } diff --git a/core/src/raw/adapters/typed_kv/backend.rs b/core/src/raw/adapters/typed_kv/backend.rs index db6f741e4ba5..c33acfecdc84 100644 --- a/core/src/raw/adapters/typed_kv/backend.rs +++ b/core/src/raw/adapters/typed_kv/backend.rs @@ -16,15 +16,11 @@ // under the License. use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; + use std::vec::IntoIter; use async_trait::async_trait; use bytes::Bytes; -use futures::future::BoxFuture; -use futures::FutureExt; use super::Adapter; use super::Value; @@ -271,7 +267,6 @@ pub struct KvWriter { op: OpWrite, buf: Option>, value: Option, - future: Option>>, } /// # Safety @@ -287,7 +282,6 @@ impl KvWriter { op, buf: None, value: None, - future: None, } } @@ -312,63 +306,32 @@ impl KvWriter { } impl oio::Write for KvWriter { - fn poll_write(&mut self, _: &mut Context<'_>, bs: Bytes) -> Poll> { - if self.future.is_some() { - self.future = None; - return Poll::Ready(Err(Error::new( - ErrorKind::Unexpected, - "there is a future on going, it's maybe a bug to go into this case", - ))); - } - + async fn write(&mut self, bs: Bytes) -> Result { let size = bs.len(); let mut buf = self.buf.take().unwrap_or_else(|| Vec::with_capacity(size)); buf.extend_from_slice(&bs); self.buf = Some(buf); - - Poll::Ready(Ok(size)) + Ok(size) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match self.future.as_mut() { - Some(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.future = None; - return Poll::Ready(res); - } - None => { - let kv = self.kv.clone(); - let path = self.path.clone(); - let value = match &self.value { - Some(value) => value.clone(), - None => { - let value = self.build(); - self.value = Some(value.clone()); - value - } - }; - - let fut = async move { kv.set(&path, value).await }; - self.future = Some(Box::pin(fut)); - } + async fn close(&mut self) -> Result<()> { + let value = match &self.value { + Some(value) => value.clone(), + None => { + let value = self.build(); + self.value = Some(value.clone()); + value } - } + }; + self.kv.set(&self.path, value).await?; + Ok(()) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - if self.future.is_some() { - self.future = None; - return Poll::Ready(Err(Error::new( - ErrorKind::Unexpected, - "there is a future on going, it's maybe a bug to go into this case", - ))); - } - + async fn abort(&mut self) -> Result<()> { self.buf = None; - Poll::Ready(Ok(())) + Ok(()) } } diff --git a/core/src/raw/enum_utils.rs b/core/src/raw/enum_utils.rs index 231f5a5fb3eb..be39f0df5bbf 100644 --- a/core/src/raw/enum_utils.rs +++ b/core/src/raw/enum_utils.rs @@ -39,8 +39,6 @@ //! type_alias_impl_trait has been stabilized. use std::io::SeekFrom; -use std::task::Context; -use std::task::Poll; use bytes::Bytes; @@ -90,24 +88,24 @@ impl oio::BlockingRead for TwoWa } impl oio::Write for TwoWays { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { match self { - Self::One(v) => v.poll_write(cx, bs), - Self::Two(v) => v.poll_write(cx, bs), + Self::One(v) => v.write(bs).await, + Self::Two(v) => v.write(bs).await, } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { match self { - Self::One(v) => v.poll_close(cx), - Self::Two(v) => v.poll_close(cx), + Self::One(v) => v.close().await, + Self::Two(v) => v.close().await, } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { match self { - Self::One(v) => v.poll_abort(cx), - Self::Two(v) => v.poll_abort(cx), + Self::One(v) => v.abort().await, + Self::Two(v) => v.abort().await, } } } @@ -165,27 +163,27 @@ impl o impl oio::Write for ThreeWays { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { match self { - Self::One(v) => v.poll_write(cx, bs), - Self::Two(v) => v.poll_write(cx, bs), - Self::Three(v) => v.poll_write(cx, bs), + Self::One(v) => v.write(bs).await, + Self::Two(v) => v.write(bs).await, + Self::Three(v) => v.write(bs).await, } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { match self { - Self::One(v) => v.poll_close(cx), - Self::Two(v) => v.poll_close(cx), - Self::Three(v) => v.poll_close(cx), + Self::One(v) => v.close().await, + Self::Two(v) => v.close().await, + Self::Three(v) => v.close().await, } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { match self { - Self::One(v) => v.poll_abort(cx), - Self::Two(v) => v.poll_abort(cx), - Self::Three(v) => v.poll_abort(cx), + Self::One(v) => v.abort().await, + Self::Two(v) => v.abort().await, + Self::Three(v) => v.abort().await, } } } diff --git a/core/src/raw/oio/write/api.rs b/core/src/raw/oio/write/api.rs index a352d362bcde..cb6707df7314 100644 --- a/core/src/raw/oio/write/api.rs +++ b/core/src/raw/oio/write/api.rs @@ -19,10 +19,9 @@ use bytes::Bytes; use std::fmt::Display; use std::fmt::Formatter; use std::future::Future; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; +use std::ops::DerefMut; +use crate::raw::*; use crate::*; /// WriteOperation is the name for APIs of Writer. @@ -71,7 +70,7 @@ impl From for &'static str { } /// Writer is a type erased [`Write`] -pub type Writer = Box; +pub type Writer = Box; /// Write is the trait that OpenDAL returns to callers. pub trait Write: Unpin + Send + Sync { @@ -84,117 +83,77 @@ pub trait Write: Unpin + Send + Sync { /// /// It's possible that `n < bs.len()`, caller should pass the remaining bytes /// repeatedly until all bytes has been written. - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll>; + #[cfg(not(target_arch = "wasm32"))] + fn write(&mut self, bs: Bytes) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write(&mut self, bs: Bytes) -> impl Future>; /// Close the writer and make sure all data has been flushed. - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll>; + #[cfg(not(target_arch = "wasm32"))] + fn close(&mut self) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn close(&mut self) -> impl Future>; /// Abort the pending writer. - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll>; + #[cfg(not(target_arch = "wasm32"))] + fn abort(&mut self) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn abort(&mut self) -> impl Future>; } impl Write for () { - fn poll_write(&mut self, _: &mut Context<'_>, _: Bytes) -> Poll> { + async fn write(&mut self, _: Bytes) -> Result { unimplemented!("write is required to be implemented for oio::Write") } - fn poll_close(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn close(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "output writer doesn't support close", - ))) + )) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "output writer doesn't support abort", - ))) + )) } } -/// `Box` won't implement `Write` automatically. -/// -/// To make Writer work as expected, we must add this impl. -impl Write for Box { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - (**self).poll_write(cx, bs) - } +pub trait WriteDyn: Unpin + Send + Sync { + fn write_dyn(&mut self, bs: Bytes) -> BoxedFuture>; - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - (**self).poll_close(cx) - } + fn close_dyn(&mut self) -> BoxedFuture>; - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - (**self).poll_abort(cx) - } + fn abort_dyn(&mut self) -> BoxedFuture>; } -/// Impl WriteExt for all T: Write -impl WriteExt for T {} - -/// Extension of [`Read`] to make it easier for use. -pub trait WriteExt: Write { - /// Build a future for `poll_write`. - fn write(&mut self, buf: Bytes) -> WriteFuture<'_, Self> { - WriteFuture { writer: self, buf } +impl WriteDyn for T { + fn write_dyn(&mut self, bs: Bytes) -> BoxedFuture> { + Box::pin(self.write(bs)) } - /// Build a future for `poll_close`. - fn close(&mut self) -> CloseFuture { - CloseFuture { writer: self } + fn close_dyn(&mut self) -> BoxedFuture> { + Box::pin(self.close()) } - /// Build a future for `poll_abort`. - fn abort(&mut self) -> AbortFuture { - AbortFuture { writer: self } + fn abort_dyn(&mut self) -> BoxedFuture> { + Box::pin(self.abort()) } } -pub struct WriteFuture<'a, W: Write + Unpin + ?Sized> { - writer: &'a mut W, - buf: Bytes, -} - -impl Future for WriteFuture<'_, W> -where - W: Write + Unpin + ?Sized, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - this.writer.poll_write(cx, this.buf.clone()) +impl Write for Box { + async fn write(&mut self, bs: Bytes) -> Result { + self.deref_mut().write_dyn(bs).await } -} - -pub struct AbortFuture<'a, W: Write + Unpin + ?Sized> { - writer: &'a mut W, -} -impl Future for AbortFuture<'_, W> -where - W: Write + Unpin + ?Sized, -{ - type Output = Result<()>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.writer.poll_abort(cx) + async fn close(&mut self) -> Result<()> { + self.deref_mut().close_dyn().await } -} - -pub struct CloseFuture<'a, W: Write + Unpin + ?Sized> { - writer: &'a mut W, -} - -impl Future for CloseFuture<'_, W> -where - W: Write + Unpin + ?Sized, -{ - type Output = Result<()>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.writer.poll_close(cx) + async fn abort(&mut self) -> Result<()> { + self.deref_mut().abort_dyn().await } } diff --git a/core/src/raw/oio/write/append_write.rs b/core/src/raw/oio/write/append_write.rs index cec8ee023b8e..3d2c97c3bb71 100644 --- a/core/src/raw/oio/write/append_write.rs +++ b/core/src/raw/oio/write/append_write.rs @@ -15,11 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::task::ready; -use std::task::Context; -use std::task::Poll; +use std::future::Future; -use async_trait::async_trait; use bytes::Bytes; use crate::raw::*; @@ -41,16 +38,25 @@ use crate::*; /// /// - Must be a http service that could accept `AsyncBody`. /// - Provide a way to get the current offset of the append object. -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait AppendWrite: Send + Sync + Unpin + 'static { /// Get the current offset of the append object. /// /// Returns `0` if the object is not exist. - async fn offset(&self) -> Result; + #[cfg(not(target_arch = "wasm32"))] + fn offset(&self) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn offset(&self) -> impl Future>; /// Append the data to the end of this object. - async fn append(&self, offset: u64, size: u64, body: AsyncBody) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn append( + &self, + offset: u64, + size: u64, + body: AsyncBody, + ) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn append(&self, offset: u64, size: u64, body: AsyncBody) -> impl Future>; } /// AppendWriter will implements [`Write`] based on append object. @@ -59,32 +65,20 @@ pub trait AppendWrite: Send + Sync + Unpin + 'static { /// /// - Allow users to switch to un-buffered mode if users write 16MiB every time. pub struct AppendWriter { - state: State, + inner: W, offset: Option, } -enum State { - Idle(Option), - Offset(BoxedStaticFuture<(W, Result)>), - Append(BoxedStaticFuture<(W, Result)>), -} - /// # Safety /// /// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} - -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} impl AppendWriter { /// Create a new AppendWriter. pub fn new(inner: W) -> Self { Self { - state: State::Idle(Some(inner)), + inner, offset: None, } } @@ -94,53 +88,30 @@ impl oio::Write for AppendWriter where W: AppendWrite, { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - loop { - match &mut self.state { - State::Idle(w) => { - let w = w.take().expect("writer must be valid"); - match self.offset { - Some(offset) => { - let size = bs.len(); - let bs = bs.clone(); - self.state = State::Append(Box::pin(async move { - let res = w.append(offset, size as u64, AsyncBody::Bytes(bs)).await; - - (w, res.map(|_| size)) - })); - } - None => { - self.state = State::Offset(Box::pin(async move { - let offset = w.offset().await; - - (w, offset) - })); - } - } - } - State::Offset(fut) => { - let (w, offset) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - self.offset = Some(offset?); - } - State::Append(fut) => { - let (w, size) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - - let size = size?; - // Update offset after succeed. - self.offset = self.offset.map(|offset| offset + size as u64); - return Poll::Ready(Ok(size)); - } + async fn write(&mut self, bs: Bytes) -> Result { + let offset = match self.offset { + Some(offset) => offset, + None => { + let offset = self.inner.offset().await?; + self.offset = Some(offset); + offset } - } + }; + + let size = bs.len(); + self.inner + .append(offset, size as u64, AsyncBody::Bytes(bs)) + .await?; + // Update offset after succeed. + self.offset = Some(offset + size as u64); + Ok(size) } - fn poll_close(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn close(&mut self) -> Result<()> { + Ok(()) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn abort(&mut self) -> Result<()> { + Ok(()) } } diff --git a/core/src/raw/oio/write/block_write.rs b/core/src/raw/oio/write/block_write.rs index 04aa76836516..2e836afe75e5 100644 --- a/core/src/raw/oio/write/block_write.rs +++ b/core/src/raw/oio/write/block_write.rs @@ -17,11 +17,10 @@ use std::pin::Pin; use std::sync::Arc; -use std::task::ready; + use std::task::Context; use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; use futures::Future; use futures::FutureExt; @@ -62,15 +61,16 @@ use crate::*; /// - Don't need initialization before writing. /// - Block ID is generated by caller `BlockWrite` instead of services. /// - Complete block by an ordered block id list. -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait BlockWrite: Send + Sync + Unpin + 'static { /// write_once is used to write the data to underlying storage at once. /// /// BlockWriter will call this API when: /// /// - All the data has been written to the buffer and we can perform the upload at once. - async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn write_once(&self, size: u64, body: AsyncBody) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write_once(&self, size: u64, body: AsyncBody) -> impl Future>; /// write_block will write a block of the data and returns the result /// [`Block`]. @@ -79,14 +79,33 @@ pub trait BlockWrite: Send + Sync + Unpin + 'static { /// order. /// /// - block_id is the id of the block. - async fn write_block(&self, block_id: Uuid, size: u64, body: AsyncBody) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn write_block( + &self, + block_id: Uuid, + size: u64, + body: AsyncBody, + ) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write_block( + &self, + block_id: Uuid, + size: u64, + body: AsyncBody, + ) -> impl Future>; /// complete_block will complete the block upload to build the final /// file. - async fn complete_block(&self, block_ids: Vec) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn complete_block(&self, block_ids: Vec) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn complete_block(&self, block_ids: Vec) -> impl Future>; /// abort_block will cancel the block upload and purge all data. - async fn abort_block(&self, block_ids: Vec) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn abort_block(&self, block_ids: Vec) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn abort_block(&self, block_ids: Vec) -> impl Future>; } /// WriteBlockResult is the result returned by [`WriteBlockFuture`]. @@ -135,7 +154,6 @@ impl WriteBlockFuture { /// BlockWriter will implements [`Write`] based on block /// uploads. pub struct BlockWriter { - state: State, w: Arc, block_ids: Vec, @@ -144,27 +162,14 @@ pub struct BlockWriter { futures: ConcurrentFutures, } -enum State { - Idle, - Close(BoxedStaticFuture>), - Abort(BoxedStaticFuture>), -} - /// # Safety /// /// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} impl BlockWriter { /// Create a new BlockWriter. pub fn new(inner: W, concurrent: usize) -> Self { Self { - state: State::Idle, - w: Arc::new(inner), block_ids: Vec::new(), cache: None, @@ -185,144 +190,92 @@ impl oio::Write for BlockWriter where W: BlockWrite, { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { loop { - match &mut self.state { - State::Idle => { - if self.futures.has_remaining() { - // Fill cache with the first write. - if self.cache.is_none() { - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } - - let cache = self.cache.take().expect("pending write must exist"); - self.futures.push_back(WriteBlockFuture::new( + if self.futures.has_remaining() { + // Fill cache with the first write. + if self.cache.is_none() { + let size = self.fill_cache(bs); + return Ok(size); + } + + let cache = self.cache.take().expect("pending write must exist"); + self.futures.push_back(WriteBlockFuture::new( + self.w.clone(), + Uuid::new_v4(), + cache, + )); + + let size = self.fill_cache(bs); + return Ok(size); + } else if let Some(res) = self.futures.next().await { + match res { + Ok(block_id) => { + self.block_ids.push(block_id); + continue; + } + Err((block_id, bytes, err)) => { + self.futures.push_front(WriteBlockFuture::new( self.w.clone(), - Uuid::new_v4(), - cache, + block_id, + bytes, )); - - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } else if let Some(res) = ready!(self.futures.poll_next_unpin(cx)) { - match res { - Ok(block_id) => { - self.block_ids.push(block_id); - } - Err((block_id, bytes, err)) => { - self.futures.push_front(WriteBlockFuture::new( - self.w.clone(), - block_id, - bytes, - )); - return Poll::Ready(Err(err)); - } - } + return Err(err); } } - State::Close(_) => { - unreachable!("BlockWriter must not go into State::Close during poll_write") - } - State::Abort(_) => { - unreachable!("BlockWriter must not go into State::Abort during poll_write") - } } } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { + // No write block has been sent. + if self.futures.is_empty() && self.block_ids.is_empty() { + let (size, body) = match self.cache.clone() { + Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), + None => (0, AsyncBody::Empty), + }; + self.w.write_once(size as u64, body).await?; + // Cleanup cache after write succeed. + self.cache = None; + return Ok(()); + } + loop { - match &mut self.state { - State::Idle => { - // No write block has been sent. - if self.futures.is_empty() && self.block_ids.is_empty() { - let w = self.w.clone(); - let (size, body) = match self.cache.clone() { - Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), - None => (0, AsyncBody::Empty), - }; - // Call write_once if there is no data in buffer and no location. - self.state = - State::Close(Box::pin( - async move { w.write_once(size as u64, body).await }, - )); - continue; - } + if self.futures.has_remaining() { + // Push into the queue and continue. + // It's safe to take the cache here since we will re-push task for it failed. + if let Some(cache) = self.cache.take() { + self.futures.push_back(WriteBlockFuture::new( + self.w.clone(), + Uuid::new_v4(), + cache, + )); + } + } - if self.futures.has_remaining() { - if let Some(cache) = self.cache.take() { - self.futures.push_back(WriteBlockFuture::new( - self.w.clone(), - Uuid::new_v4(), - cache, - )); - } - } + let Some(result) = self.futures.next().await else { + break; + }; - if !self.futures.is_empty() { - while let Some(result) = ready!(self.futures.poll_next_unpin(cx)) { - match result { - Ok(block_id) => { - self.block_ids.push(block_id); - } - Err((block_id, bytes, err)) => { - self.futures.push_front(WriteBlockFuture::new( - self.w.clone(), - block_id, - bytes, - )); - return Poll::Ready(Err(err)); - } - } - } - } else { - let w = self.w.clone(); - let block_ids = self.block_ids.clone(); - self.state = - State::Close(Box::pin( - async move { w.complete_block(block_ids).await }, - )); - continue; - } + match result { + Ok(block_id) => { + self.block_ids.push(block_id); + continue; } - State::Close(fut) => { - let res = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle; - // We should check res first before clean up cache. - res?; - self.cache = None; - - return Poll::Ready(Ok(())); - } - State::Abort(_) => { - unreachable!("BlockWriter must not go into State::Abort during poll_close") + Err((block_id, bytes, err)) => { + self.futures + .push_front(WriteBlockFuture::new(self.w.clone(), block_id, bytes)); + return Err(err); } } } + + let block_ids = self.block_ids.clone(); + self.w.complete_block(block_ids).await } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - State::Idle => { - let w = self.w.clone(); - let block_ids = self.block_ids.clone(); - self.futures.clear(); - self.cache = None; - self.state = - State::Abort(Box::pin(async move { w.abort_block(block_ids).await })); - } - State::Abort(fut) => { - let res = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle; - return Poll::Ready(res); - } - State::Close(_) => { - unreachable!("BlockWriter must not go into State::Close during poll_abort") - } - } - } + async fn abort(&mut self) -> Result<()> { + self.w.abort_block(self.block_ids.clone()).await } } @@ -339,8 +292,8 @@ mod tests { use super::*; use crate::raw::oio::StreamExt; + use crate::raw::oio::Write; use crate::raw::oio::WriteBuf; - use crate::raw::oio::WriteExt; struct TestWrite { length: u64, @@ -360,8 +313,6 @@ mod tests { } } - #[cfg_attr(not(target_arch = "wasm32"), async_trait)] - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl BlockWrite for Arc> { async fn write_once(&self, _: u64, _: AsyncBody) -> Result<()> { Ok(()) diff --git a/core/src/raw/oio/write/exact_buf_write.rs b/core/src/raw/oio/write/exact_buf_write.rs index 20a19e67256c..cd95055c8316 100644 --- a/core/src/raw/oio/write/exact_buf_write.rs +++ b/core/src/raw/oio/write/exact_buf_write.rs @@ -17,9 +17,6 @@ use bytes::Bytes; use std::cmp::min; -use std::task::ready; -use std::task::Context; -use std::task::Poll; use crate::raw::oio::WriteBuf; use crate::raw::*; @@ -56,31 +53,31 @@ impl ExactBufWriter { } impl oio::Write for ExactBufWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { if self.buffer.len() >= self.buffer_size { let bs = self.buffer.bytes(self.buffer.remaining()); - let written = ready!(self.inner.poll_write(cx, bs)?); + let written = self.inner.write(bs).await?; self.buffer.advance(written); } let remaining = min(self.buffer_size - self.buffer.len(), bs.len()); self.buffer.push(bs.slice(0..remaining)); - Poll::Ready(Ok(remaining)) + Ok(remaining) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { while !self.buffer.is_empty() { let bs = self.buffer.bytes(self.buffer.remaining()); - let n = ready!(self.inner.poll_write(cx, bs))?; + let n = self.inner.write(bs).await?; self.buffer.advance(n); } - self.inner.poll_close(cx) + self.inner.close().await } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { self.buffer.clear(); - self.inner.poll_abort(cx) + self.inner.abort().await } } @@ -97,26 +94,25 @@ mod tests { use super::*; use crate::raw::oio::Write; - use crate::raw::oio::WriteExt; struct MockWriter { buf: Vec, } impl Write for MockWriter { - fn poll_write(&mut self, _: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { debug!("test_fuzz_exact_buf_writer: flush size: {}", &bs.len()); self.buf.extend_from_slice(&bs); - Poll::Ready(Ok(bs.len())) + Ok(bs.len()) } - fn poll_close(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn close(&mut self) -> Result<()> { + Ok(()) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + async fn abort(&mut self) -> Result<()> { + Ok(()) } } diff --git a/core/src/raw/oio/write/mod.rs b/core/src/raw/oio/write/mod.rs index 5f437e2525fa..66ffa2f4c0f0 100644 --- a/core/src/raw/oio/write/mod.rs +++ b/core/src/raw/oio/write/mod.rs @@ -19,7 +19,6 @@ mod api; pub use api::BlockingWrite; pub use api::BlockingWriter; pub use api::Write; -pub use api::WriteExt; pub use api::WriteOperation; pub use api::Writer; diff --git a/core/src/raw/oio/write/multipart_write.rs b/core/src/raw/oio/write/multipart_write.rs index cdf2d5408c84..10b260210a7f 100644 --- a/core/src/raw/oio/write/multipart_write.rs +++ b/core/src/raw/oio/write/multipart_write.rs @@ -17,11 +17,10 @@ use std::pin::Pin; use std::sync::Arc; -use std::task::ready; + use std::task::Context; use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; use futures::Future; use futures::FutureExt; @@ -61,15 +60,16 @@ use crate::*; /// - Don't need initialization before writing. /// - Block ID is generated by caller `BlockWrite` instead of services. /// - Complete block by an ordered block id list. -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait MultipartWrite: Send + Sync + Unpin + 'static { /// write_once is used to write the data to underlying storage at once. /// /// MultipartWriter will call this API when: /// /// - All the data has been written to the buffer and we can perform the upload at once. - async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn write_once(&self, size: u64, body: AsyncBody) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write_once(&self, size: u64, body: AsyncBody) -> impl Future>; /// initiate_part will call start a multipart upload and return the upload id. /// @@ -78,7 +78,10 @@ pub trait MultipartWrite: Send + Sync + Unpin + 'static { /// - the total size of data is unknown. /// - the total size of data is known, but the size of current write /// is less then the total size. - async fn initiate_part(&self) -> Result; + #[cfg(not(target_arch = "wasm32"))] + fn initiate_part(&self) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn initiate_part(&self) -> impl Future>; /// write_part will write a part of the data and returns the result /// [`MultipartPart`]. @@ -87,20 +90,43 @@ pub trait MultipartWrite: Send + Sync + Unpin + 'static { /// order. /// /// - part_number is the index of the part, starting from 0. - async fn write_part( + #[cfg(not(target_arch = "wasm32"))] + fn write_part( + &self, + upload_id: &str, + part_number: usize, + size: u64, + body: AsyncBody, + ) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write_part( &self, upload_id: &str, part_number: usize, size: u64, body: AsyncBody, - ) -> Result; + ) -> impl Future>; /// complete_part will complete the multipart upload to build the final /// file. - async fn complete_part(&self, upload_id: &str, parts: &[MultipartPart]) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn complete_part( + &self, + upload_id: &str, + parts: &[MultipartPart], + ) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn complete_part( + &self, + upload_id: &str, + parts: &[MultipartPart], + ) -> impl Future>; /// abort_part will cancel the multipart upload and purge all data. - async fn abort_part(&self, upload_id: &str) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn abort_part(&self, upload_id: &str) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn abort_part(&self, upload_id: &str) -> impl Future>; } /// The result of [`MultipartWrite::write_part`]. @@ -166,7 +192,6 @@ impl WritePartFuture { /// MultipartWriter will implements [`Write`] based on multipart /// uploads. pub struct MultipartWriter { - state: State, w: Arc, upload_id: Option>, @@ -176,28 +201,14 @@ pub struct MultipartWriter { next_part_number: usize, } -enum State { - Idle, - Init(BoxedStaticFuture>), - Close(BoxedStaticFuture>), - Abort(BoxedStaticFuture>), -} - /// # Safety /// /// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} impl MultipartWriter { /// Create a new MultipartWriter. pub fn new(inner: W, concurrent: usize) -> Self { Self { - state: State::Idle, - w: Arc::new(inner), upload_id: None, parts: Vec::new(), @@ -220,187 +231,125 @@ impl oio::Write for MultipartWriter where W: MultipartWrite, { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - loop { - match &mut self.state { - State::Idle => { - match self.upload_id.as_ref() { - Some(upload_id) => { - if self.futures.has_remaining() { - let cache = self.cache.take().expect("pending write must exist"); - let part_number = self.next_part_number; - self.next_part_number += 1; - - self.futures.push_back(WritePartFuture::new( - self.w.clone(), - upload_id.clone(), - part_number, - cache, - )); - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } - - if let Some(part) = ready!(self.futures.poll_next_unpin(cx)) { - match part { - Ok(part) => { - self.parts.push(part); - } - Err((part_number, bytes, err)) => { - self.futures.push_front(WritePartFuture::new( - self.w.clone(), - upload_id.clone(), - part_number, - bytes, - )); - return Poll::Ready(Err(err)); - } - } - } - } - None => { - // Fill cache with the first write. - if self.cache.is_none() { - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } - - let w = self.w.clone(); - self.state = - State::Init(Box::pin(async move { w.initiate_part().await })); - } - } - } - State::Init(fut) => { - let upload_id = ready!(fut.as_mut().poll(cx)); - // Make sure the future is dropped after it returned ready. - self.state = State::Idle; - self.upload_id = Some(Arc::new(upload_id?)); - } - State::Close(_) => { - unreachable!("MultipartWriter must not go into State::Close during poll_write") - } - State::Abort(_) => { - unreachable!("MultipartWriter must not go into State::Abort during poll_write") + async fn write(&mut self, bs: Bytes) -> Result { + let upload_id = match self.upload_id.clone() { + Some(v) => v, + None => { + // Fill cache with the first write. + if self.cache.is_none() { + let size = self.fill_cache(bs); + return Ok(size); } + + let upload_id = self.w.initiate_part().await?; + let upload_id = Arc::new(upload_id); + self.upload_id = Some(upload_id.clone()); + upload_id } - } - } + }; - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match &mut self.state { - State::Idle => { - match self.upload_id.as_ref() { - Some(upload_id) => { - // futures queue is empty and cache is consumed, we can complete the upload. - if self.futures.is_empty() && self.cache.is_none() { - let w = self.w.clone(); - let upload_id = upload_id.clone(); - let parts = self.parts.clone(); - - self.state = State::Close(Box::pin(async move { - w.complete_part(&upload_id, &parts).await - })); - continue; - } - - if self.futures.has_remaining() { - // This must be the final task. - if let Some(cache) = self.cache.take() { - let part_number = self.next_part_number; - self.next_part_number += 1; - - self.futures.push_back(WritePartFuture::new( - self.w.clone(), - upload_id.clone(), - part_number, - cache, - )); - } - } - - if let Some(part) = ready!(self.futures.poll_next_unpin(cx)) { - match part { - Ok(part) => { - self.parts.push(part); - } - Err((part_number, bytes, err)) => { - self.futures.push_front(WritePartFuture::new( - self.w.clone(), - upload_id.clone(), - part_number, - bytes, - )); - return Poll::Ready(Err(err)); - } - } - } - } - None => { - let w = self.w.clone(); - let (size, body) = match self.cache.clone() { - Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), - None => (0, AsyncBody::Empty), - }; - // Call write_once if there is no upload_id. - self.state = State::Close(Box::pin(async move { - w.write_once(size as u64, body).await - })); - } + if self.futures.has_remaining() { + let cache = self.cache.take().expect("pending write must exist"); + let part_number = self.next_part_number; + self.next_part_number += 1; + + self.futures.push_back(WritePartFuture::new( + self.w.clone(), + upload_id.clone(), + part_number, + cache, + )); + let size = self.fill_cache(bs); + return Ok(size); + } + + if let Some(part) = self.futures.next().await { + match part { + Ok(part) => { + self.parts.push(part); + continue; + } + Err((part_number, bytes, err)) => { + self.futures.push_front(WritePartFuture::new( + self.w.clone(), + upload_id.clone(), + part_number, + bytes, + )); + return Err(err); } - } - State::Close(fut) => { - let res = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle; - // We should check res first before clean up cache. - res?; - self.cache = None; - - return Poll::Ready(Ok(())); - } - State::Init(_) => { - unreachable!("MultipartWriter must not go into State::Init during poll_close") - } - State::Abort(_) => { - unreachable!("MultipartWriter must not go into State::Abort during poll_close") } } } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { + let upload_id = match self.upload_id.clone() { + Some(v) => v, + None => { + let (size, body) = match self.cache.clone() { + Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), + None => (0, AsyncBody::Empty), + }; + // Call write_once if there is no upload_id. + self.w.write_once(size as u64, body).await?; + self.cache = None; + return Ok(()); + } + }; + loop { - match &mut self.state { - State::Idle => { - let w = self.w.clone(); - match self.upload_id.clone() { - Some(upload_id) => { - self.futures.clear(); - self.state = - State::Abort(Box::pin( - async move { w.abort_part(&upload_id).await }, - )); - } - None => { - return Poll::Ready(Ok(())); - } - } - } - State::Abort(fut) => { - let res = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle; - return Poll::Ready(res); - } - State::Init(_) => { - unreachable!("MultipartWriter must not go into State::Init during poll_abort") + // futures queue is empty and cache is consumed, we can complete the upload. + if self.futures.is_empty() && self.cache.is_none() { + return self.w.complete_part(&upload_id, &self.parts).await; + } + + if self.futures.has_remaining() { + // This must be the final task. + if let Some(cache) = self.cache.take() { + let part_number = self.next_part_number; + self.next_part_number += 1; + + self.futures.push_back(WritePartFuture::new( + self.w.clone(), + upload_id.clone(), + part_number, + cache, + )); } - State::Close(_) => { - unreachable!("MultipartWriter must not go into State::Close during poll_abort") + } + + if let Some(part) = self.futures.next().await { + match part { + Ok(part) => { + self.parts.push(part); + continue; + } + Err((part_number, bytes, err)) => { + self.futures.push_front(WritePartFuture::new( + self.w.clone(), + upload_id.clone(), + part_number, + bytes, + )); + return Err(err); + } } } } } + + async fn abort(&mut self) -> Result<()> { + let Some(upload_id) = self.upload_id.clone() else { + return Ok(()); + }; + + self.futures.clear(); + self.w.abort_part(&upload_id).await?; + self.cache = None; + Ok(()) + } } #[cfg(test)] @@ -413,7 +362,7 @@ mod tests { use rand::RngCore; use super::*; - use crate::raw::oio::WriteExt; + use crate::raw::oio::Write; struct TestWrite { upload_id: String, @@ -433,8 +382,6 @@ mod tests { } } - #[cfg_attr(not(target_arch = "wasm32"), async_trait)] - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl MultipartWrite for Arc> { async fn write_once(&self, size: u64, _: AsyncBody) -> Result<()> { self.lock().unwrap().length += size; diff --git a/core/src/raw/oio/write/one_shot_write.rs b/core/src/raw/oio/write/one_shot_write.rs index ab8ad41565b8..56116b8e41f2 100644 --- a/core/src/raw/oio/write/one_shot_write.rs +++ b/core/src/raw/oio/write/one_shot_write.rs @@ -15,11 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::task::ready; -use std::task::Context; -use std::task::Poll; +use std::future::Future; -use async_trait::async_trait; use bytes::Bytes; use crate::raw::*; @@ -31,102 +28,56 @@ use crate::*; /// For example, S3 `PUT Object` and fs `write_all`. /// /// The layout after adopting [`OneShotWrite`]: -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait OneShotWrite: Send + Sync + Unpin + 'static { /// write_once write all data at once. /// /// Implementations should make sure that the data is written correctly at once. - async fn write_once(&self, bs: Bytes) -> Result<()>; + #[cfg(not(target_arch = "wasm32"))] + fn write_once(&self, bs: Bytes) -> impl Future> + Send; + #[cfg(target_arch = "wasm32")] + fn write_once(&self, bs: Bytes) -> impl Future>; } /// OneShotWrite is used to implement [`Write`] based on one shot. pub struct OneShotWriter { - state: State, + inner: W, buffer: Option, } -enum State { - Idle(Option), - Write(BoxedStaticFuture<(W, Result<()>)>), -} - -/// # Safety -/// -/// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} - -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} - impl OneShotWriter { /// Create a new one shot writer. pub fn new(inner: W) -> Self { Self { - state: State::Idle(Some(inner)), + inner, buffer: None, } } } impl oio::Write for OneShotWriter { - fn poll_write(&mut self, _: &mut Context<'_>, bs: Bytes) -> Poll> { - match &mut self.state { - State::Idle(_) => match &self.buffer { - Some(_) => Poll::Ready(Err(Error::new( - ErrorKind::Unsupported, - "OneShotWriter doesn't support multiple write", - ))), - None => { - let size = bs.len(); - self.buffer = Some(bs); - Poll::Ready(Ok(size)) - } - }, - State::Write(_) => { - unreachable!("OneShotWriter must not go into State::Write during poll_write") + async fn write(&mut self, bs: Bytes) -> Result { + match &self.buffer { + Some(_) => Err(Error::new( + ErrorKind::Unsupported, + "OneShotWriter doesn't support multiple write", + )), + None => { + let size = bs.len(); + self.buffer = Some(bs); + Ok(size) } } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - State::Idle(w) => { - let w = w.take().expect("writer must be valid"); - - match self.buffer.clone() { - Some(bs) => { - let fut = Box::pin(async move { - let res = w.write_once(bs).await; - - (w, res) - }); - self.state = State::Write(fut); - } - None => { - let fut = Box::pin(async move { - let res = w.write_once(Bytes::new()).await; - - (w, res) - }); - self.state = State::Write(fut); - } - }; - } - State::Write(fut) => { - let (w, res) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - return Poll::Ready(res); - } - } + async fn close(&mut self) -> Result<()> { + match self.buffer.clone() { + Some(bs) => self.inner.write_once(bs).await, + None => self.inner.write_once(Bytes::new()).await, } } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { + async fn abort(&mut self) -> Result<()> { self.buffer = None; - Poll::Ready(Ok(())) + Ok(()) } } diff --git a/core/src/raw/oio/write/range_write.rs b/core/src/raw/oio/write/range_write.rs index c63144e10474..f766943fbbb3 100644 --- a/core/src/raw/oio/write/range_write.rs +++ b/core/src/raw/oio/write/range_write.rs @@ -17,11 +17,10 @@ use std::pin::Pin; use std::sync::Arc; -use std::task::ready; + use std::task::Context; use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; use futures::Future; use futures::FutureExt; @@ -59,39 +58,37 @@ use crate::*; /// - Must be a http service that could accept `AsyncBody`. /// - Need initialization before writing. /// - Writing data based on range: `offset`, `size`. -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] pub trait RangeWrite: Send + Sync + Unpin + 'static { /// write_once is used to write the data to underlying storage at once. /// /// RangeWriter will call this API when: /// /// - All the data has been written to the buffer and we can perform the upload at once. - async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()>; + fn write_once(&self, size: u64, body: AsyncBody) -> impl Future> + Send; /// Initiate range the range write, the returning value is the location. - async fn initiate_range(&self) -> Result; + fn initiate_range(&self) -> impl Future> + Send; /// write_range will write a range of data. - async fn write_range( + fn write_range( &self, location: &str, offset: u64, size: u64, body: AsyncBody, - ) -> Result<()>; + ) -> impl Future> + Send; /// complete_range will complete the range write by uploading the last chunk. - async fn complete_range( + fn complete_range( &self, location: &str, offset: u64, size: u64, body: AsyncBody, - ) -> Result<()>; + ) -> impl Future> + Send; /// abort_range will abort the range write by abort all already uploaded data. - async fn abort_range(&self, location: &str) -> Result<()>; + fn abort_range(&self, location: &str) -> impl Future> + Send; } /// WritePartResult is the result returned by [`WriteRangeFuture`]. @@ -149,31 +146,12 @@ pub struct RangeWriter { futures: ConcurrentFutures, w: Arc, - state: State, -} - -enum State { - Idle, - Init(BoxedStaticFuture>), - Complete(BoxedStaticFuture>), - Abort(BoxedStaticFuture>), } -/// # Safety -/// -/// wasm32 is a special target that we only have one event-loop for this state. -unsafe impl Send for State {} - -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} - impl RangeWriter { /// Create a new MultipartWriter. pub fn new(inner: W, concurrent: usize) -> Self { Self { - state: State::Idle, w: Arc::new(inner), futures: ConcurrentFutures::new(1.max(concurrent)), @@ -193,163 +171,101 @@ impl RangeWriter { } impl oio::Write for RangeWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - loop { - match &mut self.state { - State::Idle => { - match self.location.clone() { - Some(location) => { - if self.futures.has_remaining() { - let cache = self.buffer.take().expect("cache must be valid"); - let offset = self.next_offset; - self.next_offset += cache.len() as u64; - self.futures.push_back(WriteRangeFuture::new( - self.w.clone(), - location, - offset, - cache, - )); - - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } - - if let Some(Err((offset, bytes, err))) = - ready!(self.futures.poll_next_unpin(cx)) - { - self.futures.push_front(WriteRangeFuture::new( - self.w.clone(), - location, - offset, - bytes, - )); - return Poll::Ready(Err(err)); - } - } - None => { - // Fill cache with the first write. - if self.buffer.is_none() { - let size = self.fill_cache(bs); - return Poll::Ready(Ok(size)); - } - - let w = self.w.clone(); - self.state = - State::Init(Box::pin(async move { w.initiate_range().await })); - } - } - } - State::Init(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.state = State::Idle; - self.location = Some(Arc::new(res?)); - } - State::Complete(_) => { - unreachable!("RangeWriter must not go into State::Complete during poll_write") - } - State::Abort(_) => { - unreachable!("RangeWriter must not go into State::Abort during poll_write") + async fn write(&mut self, bs: Bytes) -> Result { + let location = match self.location.clone() { + Some(location) => location, + None => { + // Fill cache with the first write. + if self.buffer.is_none() { + let size = self.fill_cache(bs); + return Ok(size); } + + let location = self.w.initiate_range().await?; + let location = Arc::new(location); + self.location = Some(location.clone()); + location } - } - } + }; - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match &mut self.state { - State::Idle => { - let w = self.w.clone(); - match self.location.clone() { - Some(location) => { - if !self.futures.is_empty() { - while let Some(result) = ready!(self.futures.poll_next_unpin(cx)) { - if let Err((offset, bytes, err)) = result { - self.futures.push_front(WriteRangeFuture::new( - self.w.clone(), - location, - offset, - bytes, - )); - return Poll::Ready(Err(err)); - }; - } - } - match self.buffer.take() { - Some(bs) => { - let offset = self.next_offset; - self.state = State::Complete(Box::pin(async move { - w.complete_range( - &location, - offset, - bs.len() as u64, - AsyncBody::ChunkedBytes(bs), - ) - .await - })); - } - None => { - unreachable!("It's must be bug that RangeWrite is in State::Idle with no cache but has location") - } - } - } - None => { - let w = self.w.clone(); - let (size, body) = match self.buffer.clone() { - Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), - None => (0, AsyncBody::Empty), - }; - // Call write_once if there is no data in buffer and no location. - - self.state = State::Complete(Box::pin(async move { - w.write_once(size as u64, body).await - })); - } - } - } - State::Init(_) => { - unreachable!("RangeWriter must not go into State::Init during poll_close") - } - State::Complete(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.state = State::Idle; - return Poll::Ready(res); - } - State::Abort(_) => { - unreachable!("RangeWriter must not go into State::Abort during poll_close") - } + if self.futures.has_remaining() { + let cache = self.buffer.take().expect("cache must be valid"); + let offset = self.next_offset; + self.next_offset += cache.len() as u64; + self.futures.push_back(WriteRangeFuture::new( + self.w.clone(), + location, + offset, + cache, + )); + + let size = self.fill_cache(bs); + return Ok(size); + } + + if let Some(Err((offset, bytes, err))) = self.futures.next().await { + self.futures.push_front(WriteRangeFuture::new( + self.w.clone(), + location, + offset, + bytes, + )); + return Err(err); } } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - State::Idle => match self.location.clone() { - Some(location) => { - let w = self.w.clone(); - self.futures.clear(); - self.state = - State::Abort(Box::pin(async move { w.abort_range(&location).await })); - } - None => return Poll::Ready(Ok(())), - }, - State::Init(_) => { - unreachable!("RangeWriter must not go into State::Init during poll_close") - } - State::Complete(_) => { - unreachable!("RangeWriter must not go into State::Complete during poll_close") - } - State::Abort(fut) => { - let res = ready!(fut.poll_unpin(cx)); - self.state = State::Idle; - // We should check res first before clean up cache. - res?; - - self.buffer = None; - return Poll::Ready(Ok(())); - } + async fn close(&mut self) -> Result<()> { + let Some(location) = self.location.clone() else { + let (size, body) = match self.buffer.clone() { + Some(cache) => (cache.len(), AsyncBody::ChunkedBytes(cache)), + None => (0, AsyncBody::Empty), + }; + // Call write_once if there is no data in buffer and no location. + return self.w.write_once(size as u64, body).await; + }; + + if !self.futures.is_empty() { + while let Some(result) = self.futures.next().await { + if let Err((offset, bytes, err)) = result { + self.futures.push_front(WriteRangeFuture::new( + self.w.clone(), + location, + offset, + bytes, + )); + return Err(err); + }; } } + + if let Some(buffer) = self.buffer.clone() { + let offset = self.next_offset; + self.w + .complete_range( + &location, + offset, + buffer.len() as u64, + AsyncBody::ChunkedBytes(buffer), + ) + .await?; + self.buffer = None; + } + + Ok(()) + } + + async fn abort(&mut self) -> Result<()> { + let Some(location) = self.location.clone() else { + return Ok(()); + }; + + self.futures.clear(); + self.w.abort_range(&location).await?; + // Clean cache when abort_range returns success. + self.buffer = None; + Ok(()) } } @@ -364,7 +280,7 @@ mod tests { use rand::RngCore; use super::*; - use crate::raw::oio::WriteExt; + use crate::raw::oio::Write; struct TestWrite { length: u64, @@ -382,8 +298,6 @@ mod tests { } } - #[cfg_attr(not(target_arch = "wasm32"), async_trait)] - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl RangeWrite for Arc> { async fn write_once(&self, size: u64, _: AsyncBody) -> Result<()> { let mut test = self.lock().unwrap(); diff --git a/core/src/services/alluxio/writer.rs b/core/src/services/alluxio/writer.rs index 1c5c1ed5cfd6..eed45f2cab08 100644 --- a/core/src/services/alluxio/writer.rs +++ b/core/src/services/alluxio/writer.rs @@ -16,13 +16,8 @@ // under the License. use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; -use futures::future::BoxFuture; use super::core::AlluxioCore; @@ -32,24 +27,17 @@ use crate::*; pub type AlluxioWriters = AlluxioWriter; pub struct AlluxioWriter { - state: State, + core: Arc, _op: OpWrite, path: String, stream_id: Option, } -enum State { - Idle(Option>), - Init(BoxFuture<'static, (Arc, Result)>), - Write(BoxFuture<'static, (Arc, Result)>), - Close(BoxFuture<'static, (Arc, Result<()>)>), -} - impl AlluxioWriter { pub fn new(core: Arc, _op: OpWrite, path: String) -> Self { AlluxioWriter { - state: State::Idle(Some(core)), + core, _op, path, stream_id: None, @@ -57,95 +45,30 @@ impl AlluxioWriter { } } -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} - -#[async_trait] impl oio::Write for AlluxioWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - loop { - match &mut self.state { - State::Idle(w) => match self.stream_id.as_ref() { - Some(stream_id) => { - let cb = oio::ChunkedBytes::from_vec(vec![bs.clone()]); - - let stream_id = *stream_id; - - let w = w.take().expect("writer must be valid"); - - self.state = State::Write(Box::pin(async move { - let part = w.write(stream_id, AsyncBody::ChunkedBytes(cb)).await; - - (w, part) - })); - } - None => { - let path = self.path.clone(); - let w = w.take().expect("writer must be valid"); - self.state = State::Init(Box::pin(async move { - let upload_id = w.create_file(&path).await; - (w, upload_id) - })); - } - }, - State::Init(fut) => { - let (w, stream_id) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - self.stream_id = Some(stream_id?); - } - State::Write(fut) => { - let (w, part) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - return Poll::Ready(Ok(part?)); - } - State::Close(_) => { - unreachable!("MultipartWriter must not go into State::Close during poll_write") - } + async fn write(&mut self, bs: Bytes) -> Result { + let stream_id = match self.stream_id { + Some(stream_id) => stream_id, + None => { + let stream_id = self.core.create_file(&self.path).await?; + self.stream_id = Some(stream_id); + stream_id } - } + }; + self.core.write(stream_id, AsyncBody::Bytes(bs)).await } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - State::Idle(w) => { - let w = w.take().expect("writer must be valid"); - match self.stream_id { - Some(stream_id) => { - self.state = State::Close(Box::pin(async move { - let res = w.close(stream_id).await; - (w, res) - })); - } - None => { - return Poll::Ready(Ok(())); - } - } - } - State::Close(fut) => { - let (w, res) = futures::ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(w)); - - res?; - - return Poll::Ready(Ok(())); - } - State::Init(_) => { - unreachable!("AlluxioWriter must not go into State::Init during poll_close") - } - State::Write(_) => unreachable! { - "AlluxioWriter must not go into State::Write during poll_close" - }, - } - } + async fn close(&mut self) -> Result<()> { + let Some(stream_id) = self.stream_id else { + return Ok(()); + }; + self.core.close(stream_id).await } - fn poll_abort(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "AlluxioWriter doesn't support abort", - ))) + )) } } diff --git a/core/src/services/azblob/writer.rs b/core/src/services/azblob/writer.rs index e0cbc72f9aed..84025362d4a4 100644 --- a/core/src/services/azblob/writer.rs +++ b/core/src/services/azblob/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use uuid::Uuid; @@ -43,8 +42,6 @@ impl AzblobWriter { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl oio::AppendWrite for AzblobWriter { async fn offset(&self) -> Result { let resp = self @@ -111,9 +108,6 @@ impl oio::AppendWrite for AzblobWriter { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[async_trait] impl oio::BlockWrite for AzblobWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req: http::Request = diff --git a/core/src/services/azdls/writer.rs b/core/src/services/azdls/writer.rs index 32f1db96acba..6fc3709d04a1 100644 --- a/core/src/services/azdls/writer.rs +++ b/core/src/services/azdls/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -42,7 +41,6 @@ impl AzdlsWriter { } } -#[async_trait] impl oio::OneShotWrite for AzdlsWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let mut req = @@ -89,7 +87,6 @@ impl oio::OneShotWrite for AzdlsWriter { } } -#[async_trait] impl oio::AppendWrite for AzdlsWriter { async fn offset(&self) -> Result { let resp = self.core.azdls_get_properties(&self.path).await?; diff --git a/core/src/services/azfile/writer.rs b/core/src/services/azfile/writer.rs index d9253939ea24..e2c3dde86eb6 100644 --- a/core/src/services/azfile/writer.rs +++ b/core/src/services/azfile/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -40,7 +39,6 @@ impl AzfileWriter { } } -#[async_trait] impl oio::OneShotWrite for AzfileWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self @@ -65,7 +63,7 @@ impl oio::OneShotWrite for AzfileWriter { .azfile_update(&self.path, bs.len() as u64, 0, AsyncBody::Bytes(bs)) .await?; let status = resp.status(); - return match status { + match status { StatusCode::OK | StatusCode::CREATED => { resp.into_body().consume().await?; Ok(()) @@ -73,11 +71,10 @@ impl oio::OneShotWrite for AzfileWriter { _ => Err(parse_error(resp) .await? .with_operation("Backend::azfile_update")), - }; + } } } -#[async_trait] impl oio::AppendWrite for AzfileWriter { async fn offset(&self) -> Result { let resp = self.core.azfile_get_file_properties(&self.path).await?; diff --git a/core/src/services/b2/writer.rs b/core/src/services/b2/writer.rs index abbdc4dc36b7..78cb8a9fc433 100644 --- a/core/src/services/b2/writer.rs +++ b/core/src/services/b2/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::B2Core; @@ -46,7 +45,6 @@ impl B2Writer { } } -#[async_trait] impl oio::MultipartWrite for B2Writer { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let resp = self diff --git a/core/src/services/chainsafe/writer.rs b/core/src/services/chainsafe/writer.rs index b1c960afa432..79b58e324f0a 100644 --- a/core/src/services/chainsafe/writer.rs +++ b/core/src/services/chainsafe/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -44,7 +43,6 @@ impl ChainsafeWriter { } } -#[async_trait] impl oio::OneShotWrite for ChainsafeWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self.core.upload_object(&self.path, bs).await?; diff --git a/core/src/services/cos/writer.rs b/core/src/services/cos/writer.rs index a05472aa8aae..51f0ad15edf5 100644 --- a/core/src/services/cos/writer.rs +++ b/core/src/services/cos/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::*; @@ -44,7 +43,6 @@ impl CosWriter { } } -#[async_trait] impl oio::MultipartWrite for CosWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self @@ -167,7 +165,6 @@ impl oio::MultipartWrite for CosWriter { } } -#[async_trait] impl oio::AppendWrite for CosWriter { async fn offset(&self) -> Result { let resp = self diff --git a/core/src/services/dbfs/writer.rs b/core/src/services/dbfs/writer.rs index 63e4cb7f8f83..47442901335c 100644 --- a/core/src/services/dbfs/writer.rs +++ b/core/src/services/dbfs/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -40,7 +39,6 @@ impl DbfsWriter { } } -#[async_trait] impl oio::OneShotWrite for DbfsWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let size = bs.len(); diff --git a/core/src/services/dropbox/writer.rs b/core/src/services/dropbox/writer.rs index 163757fad7d5..de1f594cd6b0 100644 --- a/core/src/services/dropbox/writer.rs +++ b/core/src/services/dropbox/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -38,7 +37,6 @@ impl DropboxWriter { } } -#[async_trait] impl oio::OneShotWrite for DropboxWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self diff --git a/core/src/services/fs/writer.rs b/core/src/services/fs/writer.rs index 3ec8c72c0778..c035ff1bf94b 100644 --- a/core/src/services/fs/writer.rs +++ b/core/src/services/fs/writer.rs @@ -16,16 +16,10 @@ // under the License. use bytes::Bytes; + use std::io::Write; use std::path::PathBuf; -use std::pin::Pin; -use std::task::ready; -use std::task::Context; -use std::task::Poll; - -use futures::future::BoxFuture; -use futures::FutureExt; -use tokio::io::AsyncWrite; + use tokio::io::AsyncWriteExt; use crate::raw::*; @@ -36,7 +30,6 @@ pub struct FsWriter { tmp_path: Option, f: Option, - fut: Option>>, } impl FsWriter { @@ -46,7 +39,6 @@ impl FsWriter { tmp_path, f: Some(f), - fut: None, } } } @@ -57,60 +49,35 @@ impl FsWriter { unsafe impl Sync for FsWriter {} impl oio::Write for FsWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let f = self.f.as_mut().expect("FsWriter must be initialized"); - Pin::new(f).poll_write(cx, &bs).map_err(new_std_io_error) + f.write(&bs).await.map_err(new_std_io_error) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - if let Some(fut) = self.fut.as_mut() { - let res = ready!(fut.poll_unpin(cx)); - self.fut = None; - return Poll::Ready(res); - } + async fn close(&mut self) -> Result<()> { + let f = self.f.as_mut().expect("FsWriter must be initialized"); + f.flush().await.map_err(new_std_io_error)?; + f.sync_all().await.map_err(new_std_io_error)?; - let mut f = self.f.take().expect("FsWriter must be initialized"); - let tmp_path = self.tmp_path.clone(); - let target_path = self.target_path.clone(); - self.fut = Some(Box::pin(async move { - f.flush().await.map_err(new_std_io_error)?; - f.sync_all().await.map_err(new_std_io_error)?; - - if let Some(tmp_path) = &tmp_path { - tokio::fs::rename(tmp_path, &target_path) - .await - .map_err(new_std_io_error)?; - } - - Ok(()) - })); + if let Some(tmp_path) = &self.tmp_path { + tokio::fs::rename(tmp_path, &self.target_path) + .await + .map_err(new_std_io_error)?; } + Ok(()) } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - if let Some(fut) = self.fut.as_mut() { - let res = ready!(fut.poll_unpin(cx)); - self.fut = None; - return Poll::Ready(res); - } - - let _ = self.f.take().expect("FsWriter must be initialized"); - let tmp_path = self.tmp_path.clone(); - self.fut = Some(Box::pin(async move { - if let Some(tmp_path) = &tmp_path { - tokio::fs::remove_file(tmp_path) - .await - .map_err(new_std_io_error) - } else { - Err(Error::new( - ErrorKind::Unsupported, - "Fs doesn't support abort if atomic_write_dir is not set", - )) - } - })); + async fn abort(&mut self) -> Result<()> { + if let Some(tmp_path) = &self.tmp_path { + tokio::fs::remove_file(tmp_path) + .await + .map_err(new_std_io_error) + } else { + Err(Error::new( + ErrorKind::Unsupported, + "Fs doesn't support abort if atomic_write_dir is not set", + )) } } } diff --git a/core/src/services/ftp/writer.rs b/core/src/services/ftp/writer.rs index c0b91560df1d..505176a46d27 100644 --- a/core/src/services/ftp/writer.rs +++ b/core/src/services/ftp/writer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; use bytes::Bytes; use futures::AsyncWriteExt; @@ -48,7 +47,6 @@ impl FtpWriter { /// We will only take `&mut Self` reference for FtpWriter. unsafe impl Sync for FtpWriter {} -#[async_trait] impl oio::OneShotWrite for FtpWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let mut ftp_stream = self.backend.ftp_connect(Operation::Write).await?; diff --git a/core/src/services/gcs/writer.rs b/core/src/services/gcs/writer.rs index 99b6217f2adf..0bf13a4498c3 100644 --- a/core/src/services/gcs/writer.rs +++ b/core/src/services/gcs/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::GcsCore; @@ -43,8 +42,6 @@ impl GcsWriter { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl oio::RangeWrite for GcsWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self.core.gcs_insert_object_request( diff --git a/core/src/services/gdrive/writer.rs b/core/src/services/gdrive/writer.rs index 57e03a84771d..0ebbd809803b 100644 --- a/core/src/services/gdrive/writer.rs +++ b/core/src/services/gdrive/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -46,8 +45,6 @@ impl GdriveWriter { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl oio::OneShotWrite for GdriveWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let size = bs.len(); diff --git a/core/src/services/ghac/writer.rs b/core/src/services/ghac/writer.rs index 189f63f8e04d..3db0ea597975 100644 --- a/core/src/services/ghac/writer.rs +++ b/core/src/services/ghac/writer.rs @@ -15,13 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::task::ready; -use std::task::Context; -use std::task::Poll; - -use async_trait::async_trait; use bytes::Bytes; -use futures::future::BoxFuture; use super::backend::GhacBackend; use super::error::parse_error; @@ -29,7 +23,7 @@ use crate::raw::*; use crate::*; pub struct GhacWriter { - state: State, + backend: GhacBackend, cache_id: i64, size: u64, @@ -38,120 +32,51 @@ pub struct GhacWriter { impl GhacWriter { pub fn new(backend: GhacBackend, cache_id: i64) -> Self { GhacWriter { - state: State::Idle(Some(backend)), + backend, cache_id, size: 0, } } } -enum State { - Idle(Option), - Upload(BoxFuture<'static, (GhacBackend, Result)>), - Commit(BoxFuture<'static, (GhacBackend, Result<()>)>), -} - -/// # Safety -/// -/// We will only take `&mut Self` reference for State. -unsafe impl Sync for State {} - -#[async_trait] impl oio::Write for GhacWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - loop { - match &mut self.state { - State::Idle(backend) => { - let backend = backend.take().expect("GhacWriter must be initialized"); - - let cache_id = self.cache_id; - let offset = self.size; - let size = bs.len(); - let bs = bs.clone(); - - let fut = async move { - let res = async { - let req = backend - .ghac_upload(cache_id, offset, size as u64, AsyncBody::Bytes(bs)) - .await?; - - let resp = backend.client.send(req).await?; + async fn write(&mut self, bs: Bytes) -> Result { + let size = bs.len(); + let offset = self.size; - if resp.status().is_success() { - resp.into_body().consume().await?; - Ok(size) - } else { - Err(parse_error(resp) - .await - .map(|err| err.with_operation("Backend::ghac_upload"))?) - } - } - .await; + let req = self + .backend + .ghac_upload(self.cache_id, offset, size as u64, AsyncBody::Bytes(bs)) + .await?; - (backend, res) - }; - self.state = State::Upload(Box::pin(fut)); - } - State::Upload(fut) => { - let (backend, res) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(backend)); + let resp = self.backend.client.send(req).await?; - let size = res?; - self.size += size as u64; - return Poll::Ready(Ok(size)); - } - State::Commit(_) => { - unreachable!("GhacWriter must not go into State:Commit during poll_write") - } - } + if !resp.status().is_success() { + return Err(parse_error(resp) + .await + .map(|err| err.with_operation("Backend::ghac_upload"))?); } - } - - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - self.state = State::Idle(None); - Poll::Ready(Ok(())) + resp.into_body().consume().await?; + self.size += size as u64; + Ok(size) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match &mut self.state { - State::Idle(backend) => { - let backend = backend.take().expect("GhacWriter must be initialized"); - - let cache_id = self.cache_id; - let size = self.size; - - let fut = async move { - let res = async { - let req = backend.ghac_commit(cache_id, size).await?; - let resp = backend.client.send(req).await?; - - if resp.status().is_success() { - resp.into_body().consume().await?; - Ok(()) - } else { - Err(parse_error(resp) - .await - .map(|err| err.with_operation("Backend::ghac_commit"))?) - } - } - .await; - - (backend, res) - }; - self.state = State::Commit(Box::pin(fut)); - } - State::Upload(_) => { - unreachable!("GhacWriter must not go into State:Upload during poll_close") - } - State::Commit(fut) => { - let (backend, res) = ready!(fut.as_mut().poll(cx)); - self.state = State::Idle(Some(backend)); + async fn abort(&mut self) -> Result<()> { + Ok(()) + } - return Poll::Ready(res); - } - } + async fn close(&mut self) -> Result<()> { + let req = self.backend.ghac_commit(self.cache_id, self.size).await?; + let resp = self.backend.client.send(req).await?; + + if resp.status().is_success() { + resp.into_body().consume().await?; + Ok(()) + } else { + Err(parse_error(resp) + .await + .map(|err| err.with_operation("Backend::ghac_commit"))?) } } } diff --git a/core/src/services/github/writer.rs b/core/src/services/github/writer.rs index 8801b2358ff9..0ff42415b113 100644 --- a/core/src/services/github/writer.rs +++ b/core/src/services/github/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -39,7 +38,6 @@ impl GithubWriter { } } -#[async_trait] impl oio::OneShotWrite for GithubWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self.core.upload(&self.path, bs).await?; diff --git a/core/src/services/hdfs/writer.rs b/core/src/services/hdfs/writer.rs index 0a782554c6ce..be1fc78e14d8 100644 --- a/core/src/services/hdfs/writer.rs +++ b/core/src/services/hdfs/writer.rs @@ -16,18 +16,12 @@ // under the License. use std::io::Write; -use std::pin::Pin; + use std::sync::Arc; -use std::task::ready; -use std::task::Context; -use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; -use futures::future::BoxFuture; -use futures::AsyncWrite; + use futures::AsyncWriteExt; -use futures::FutureExt; use crate::raw::*; use crate::*; @@ -37,7 +31,6 @@ pub struct HdfsWriter { tmp_path: Option, f: Option, client: Arc, - fut: Option)>>, } /// # Safety @@ -57,62 +50,36 @@ impl HdfsWriter { tmp_path, f: Some(f), client, - fut: None, } } } -#[async_trait] impl oio::Write for HdfsWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { + async fn write(&mut self, bs: Bytes) -> Result { let f = self.f.as_mut().expect("HdfsWriter must be initialized"); - Pin::new(f).poll_write(cx, &bs).map_err(new_std_io_error) + f.write(&bs).await.map_err(new_std_io_error) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - if let Some(fut) = self.fut.as_mut() { - let res = ready!(fut.poll_unpin(cx)); - self.fut = None; - if let Err(e) = res.1 { - self.f = Some(res.0); - return Poll::Ready(Err(e)); - } - return Poll::Ready(Ok(())); - } - - let mut f = self.f.take().expect("HdfsWriter must be initialized"); - let tmp_path = self.tmp_path.clone(); - let target_path = self.target_path.clone(); - // Clone client to allow move into the future. - let client = self.client.clone(); - - self.fut = Some(Box::pin(async move { - if let Err(e) = f.close().await.map_err(new_std_io_error) { - // Reserve the original file handle for retry. - return (f, Err(e)); - } - - if let Some(tmp_path) = tmp_path { - if let Err(e) = client - .rename_file(&tmp_path, &target_path) - .map_err(new_std_io_error) - { - return (f, Err(e)); - } - } - - (f, Ok(())) - })); + async fn close(&mut self) -> Result<()> { + let f = self.f.as_mut().expect("HdfsWriter must be initialized"); + f.close().await.map_err(new_std_io_error)?; + + // TODO: we need to make rename async. + if let Some(tmp_path) = &self.tmp_path { + self.client + .rename_file(tmp_path, &self.target_path) + .map_err(new_std_io_error)? } + + Ok(()) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "HdfsWriter doesn't support abort", - ))) + )) } } diff --git a/core/src/services/hdfs_native/writer.rs b/core/src/services/hdfs_native/writer.rs index 039cf3fc23e2..75ad0aeaf620 100644 --- a/core/src/services/hdfs_native/writer.rs +++ b/core/src/services/hdfs_native/writer.rs @@ -16,8 +16,6 @@ // under the License. use bytes::Bytes; -use std::task::Context; -use std::task::Poll; use hdfs_native::file::FileWriter; @@ -36,18 +34,18 @@ impl HdfsNativeWriter { } impl oio::Write for HdfsNativeWriter { - fn poll_write(&mut self, _: &mut Context<'_>, _: Bytes) -> Poll> { + async fn write(&mut self, _bs: Bytes) -> Result { todo!() } - fn poll_close(&mut self, _cx: &mut Context<'_>) -> Poll> { + async fn close(&mut self) -> Result<()> { todo!() } - fn poll_abort(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "HdfsNativeWriter doesn't support abort", - ))) + )) } } diff --git a/core/src/services/ipmfs/writer.rs b/core/src/services/ipmfs/writer.rs index 7f7f1e0dfce4..53f79b840c4f 100644 --- a/core/src/services/ipmfs/writer.rs +++ b/core/src/services/ipmfs/writer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -37,7 +36,6 @@ impl IpmfsWriter { } } -#[async_trait] impl oio::OneShotWrite for IpmfsWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self.backend.ipmfs_write(&self.path, bs).await?; diff --git a/core/src/services/koofr/writer.rs b/core/src/services/koofr/writer.rs index 9ee0e31d6e1f..97c2dbdbc9b8 100644 --- a/core/src/services/koofr/writer.rs +++ b/core/src/services/koofr/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -39,7 +38,6 @@ impl KoofrWriter { } } -#[async_trait] impl oio::OneShotWrite for KoofrWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { self.core.ensure_dir_exists(&self.path).await?; diff --git a/core/src/services/obs/writer.rs b/core/src/services/obs/writer.rs index 63d7287169ae..7df2e7cbcd04 100644 --- a/core/src/services/obs/writer.rs +++ b/core/src/services/obs/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::*; @@ -45,7 +44,6 @@ impl ObsWriter { } } -#[async_trait] impl oio::MultipartWrite for ObsWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self @@ -168,7 +166,6 @@ impl oio::MultipartWrite for ObsWriter { } } -#[async_trait] impl oio::AppendWrite for ObsWriter { async fn offset(&self) -> Result { let resp = self diff --git a/core/src/services/onedrive/writer.rs b/core/src/services/onedrive/writer.rs index 6b79f850671e..60988f20008b 100644 --- a/core/src/services/onedrive/writer.rs +++ b/core/src/services/onedrive/writer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -43,7 +42,6 @@ impl OneDriveWriter { } } -#[async_trait] impl oio::OneShotWrite for OneDriveWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let size = bs.len(); diff --git a/core/src/services/oss/writer.rs b/core/src/services/oss/writer.rs index c8015788d32c..066efe6a7c13 100644 --- a/core/src/services/oss/writer.rs +++ b/core/src/services/oss/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::*; @@ -44,7 +43,6 @@ impl OssWriter { } } -#[async_trait] impl oio::MultipartWrite for OssWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = @@ -172,7 +170,6 @@ impl oio::MultipartWrite for OssWriter { } } -#[async_trait] impl oio::AppendWrite for OssWriter { async fn offset(&self) -> Result { let resp = self.core.oss_head_object(&self.path, None, None).await?; diff --git a/core/src/services/pcloud/writer.rs b/core/src/services/pcloud/writer.rs index 36c2a2f6440b..a19ba154c394 100644 --- a/core/src/services/pcloud/writer.rs +++ b/core/src/services/pcloud/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -40,7 +39,6 @@ impl PcloudWriter { } } -#[async_trait] impl oio::OneShotWrite for PcloudWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { self.core.ensure_dir_exists(&self.path).await?; diff --git a/core/src/services/s3/writer.rs b/core/src/services/s3/writer.rs index 9b6b1bb6cc25..0f6d8194566f 100644 --- a/core/src/services/s3/writer.rs +++ b/core/src/services/s3/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::*; @@ -44,8 +43,6 @@ impl S3Writer { } } -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] impl oio::MultipartWrite for S3Writer { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let mut req = self diff --git a/core/src/services/seafile/writer.rs b/core/src/services/seafile/writer.rs index c2cb8d54b4d0..dccba7283072 100644 --- a/core/src/services/seafile/writer.rs +++ b/core/src/services/seafile/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::header; use http::Request; @@ -46,7 +45,6 @@ impl SeafileWriter { } } -#[async_trait] impl oio::OneShotWrite for SeafileWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let upload_url = self.core.get_upload_url().await?; diff --git a/core/src/services/sftp/writer.rs b/core/src/services/sftp/writer.rs index 932f115bae8a..9bee111e2440 100644 --- a/core/src/services/sftp/writer.rs +++ b/core/src/services/sftp/writer.rs @@ -16,19 +16,17 @@ // under the License. use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use async_trait::async_trait; use bytes::Bytes; use openssh_sftp_client::file::File; use openssh_sftp_client::file::TokioCompatFile; -use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; -use crate::raw::oio; +use crate::raw::{new_std_io_error, oio}; use crate::*; pub struct SftpWriter { + /// TODO: maybe we can use `File` directly? file: Pin>, } @@ -40,30 +38,19 @@ impl SftpWriter { } } -#[async_trait] impl oio::Write for SftpWriter { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: Bytes) -> Poll> { - self.file - .as_mut() - .poll_write(cx, &bs) - .map_err(new_std_io_error) + async fn write(&mut self, bs: Bytes) -> Result { + self.file.write(&bs).await.map_err(new_std_io_error) } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.file - .as_mut() - .poll_shutdown(cx) - .map_err(new_std_io_error) + async fn close(&mut self) -> Result<()> { + self.file.shutdown().await.map_err(new_std_io_error) } - fn poll_abort(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Err(Error::new( + async fn abort(&mut self) -> Result<()> { + Err(Error::new( ErrorKind::Unsupported, "SftpWriter doesn't support abort", - ))) + )) } } - -fn new_std_io_error(err: std::io::Error) -> Error { - Error::new(ErrorKind::Unexpected, "read from sftp").set_source(err) -} diff --git a/core/src/services/supabase/writer.rs b/core/src/services/supabase/writer.rs index a156c0936887..e5a753942eb5 100644 --- a/core/src/services/supabase/writer.rs +++ b/core/src/services/supabase/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -44,7 +43,6 @@ impl SupabaseWriter { } } -#[async_trait] impl oio::OneShotWrite for SupabaseWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let mut req = self.core.supabase_upload_object_request( diff --git a/core/src/services/swift/writer.rs b/core/src/services/swift/writer.rs index 046a01fa6901..6004580c7e5b 100644 --- a/core/src/services/swift/writer.rs +++ b/core/src/services/swift/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -38,7 +37,6 @@ impl SwiftWriter { } } -#[async_trait] impl oio::OneShotWrite for SwiftWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self diff --git a/core/src/services/upyun/writer.rs b/core/src/services/upyun/writer.rs index 973585f1fe8b..833fff34f3ba 100644 --- a/core/src/services/upyun/writer.rs +++ b/core/src/services/upyun/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::constants::X_UPYUN_MULTI_UUID; @@ -40,7 +39,6 @@ impl UpyunWriter { } } -#[async_trait] impl oio::MultipartWrite for UpyunWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let req = self diff --git a/core/src/services/vercel_artifacts/writer.rs b/core/src/services/vercel_artifacts/writer.rs index 6cb2bea133ab..e48b8fe5bb78 100644 --- a/core/src/services/vercel_artifacts/writer.rs +++ b/core/src/services/vercel_artifacts/writer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -41,7 +40,6 @@ impl VercelArtifactsWriter { } } -#[async_trait] impl oio::OneShotWrite for VercelArtifactsWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self diff --git a/core/src/services/vercel_blob/writer.rs b/core/src/services/vercel_blob/writer.rs index f7f3b90d759e..c3bcbf87cc63 100644 --- a/core/src/services/vercel_blob/writer.rs +++ b/core/src/services/vercel_blob/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use http::StatusCode; use super::core::InitiateMultipartUploadResponse; @@ -42,7 +41,6 @@ impl VercelBlobWriter { } } -#[async_trait] impl oio::MultipartWrite for VercelBlobWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let req = self diff --git a/core/src/services/webdav/writer.rs b/core/src/services/webdav/writer.rs index a5ca9cbaddac..525928adcd52 100644 --- a/core/src/services/webdav/writer.rs +++ b/core/src/services/webdav/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::StatusCode; @@ -40,7 +39,6 @@ impl WebdavWriter { } } -#[async_trait] impl oio::OneShotWrite for WebdavWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { let resp = self diff --git a/core/src/services/webhdfs/writer.rs b/core/src/services/webhdfs/writer.rs index 8cd935328fa3..02757a57afb9 100644 --- a/core/src/services/webhdfs/writer.rs +++ b/core/src/services/webhdfs/writer.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use async_trait::async_trait; use http::StatusCode; use uuid::Uuid; @@ -40,7 +39,6 @@ impl WebhdfsWriter { } } -#[async_trait] impl oio::BlockWrite for WebhdfsWriter { async fn write_once(&self, size: u64, body: AsyncBody) -> Result<()> { let req = self @@ -153,7 +151,6 @@ impl oio::BlockWrite for WebhdfsWriter { } } -#[async_trait] impl oio::AppendWrite for WebhdfsWriter { async fn offset(&self) -> Result { Ok(0) diff --git a/core/src/services/yandex_disk/writer.rs b/core/src/services/yandex_disk/writer.rs index 61df0e347e56..7495f59ce799 100644 --- a/core/src/services/yandex_disk/writer.rs +++ b/core/src/services/yandex_disk/writer.rs @@ -17,7 +17,6 @@ use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use http::Request; use http::StatusCode; @@ -40,7 +39,6 @@ impl YandexDiskWriter { } } -#[async_trait] impl oio::OneShotWrite for YandexDiskWriter { async fn write_once(&self, bs: Bytes) -> Result<()> { self.core.ensure_dir_exists(&self.path).await?; diff --git a/core/src/types/operator/operator.rs b/core/src/types/operator/operator.rs index 3d62471aeabe..a11d6b6d773d 100644 --- a/core/src/types/operator/operator.rs +++ b/core/src/types/operator/operator.rs @@ -18,7 +18,6 @@ use std::future::Future; use std::time::Duration; -use bytes::Buf; use bytes::Bytes; use futures::stream; use futures::Stream; @@ -27,7 +26,6 @@ use futures::TryStreamExt; use super::BlockingOperator; use crate::operator_futures::*; -use crate::raw::oio::WriteExt; use crate::raw::*; use crate::*; @@ -1258,7 +1256,7 @@ impl Operator { self.inner().clone(), path, (OpWrite::default(), bs), - |inner, path, (args, mut bs)| async move { + |inner, path, (args, bs)| async move { if !validate_path(&path, EntryMode::FILE) { return Err( Error::new(ErrorKind::IsADirectory, "write path is a directory") @@ -1268,12 +1266,9 @@ impl Operator { ); } - let (_, mut w) = inner.write(&path, args).await?; - while !bs.is_empty() { - let n = w.write(bs.clone()).await?; - bs.advance(n); - } - + let (_, w) = inner.write(&path, args).await?; + let mut w = Writer::new(w); + w.write(bs.clone()).await?; w.close().await?; Ok(()) diff --git a/core/src/types/writer.rs b/core/src/types/writer.rs index d133f108c28b..0048a1f862c9 100644 --- a/core/src/types/writer.rs +++ b/core/src/types/writer.rs @@ -16,17 +16,17 @@ // under the License. use std::io; +use std::pin::pin; use std::pin::Pin; +use std::task::ready; use std::task::Context; use std::task::Poll; +use bytes::Buf; use bytes::Bytes; -use futures::AsyncWrite; use futures::TryStreamExt; use crate::raw::oio::Write; -use crate::raw::oio::WriteBuf; -use crate::raw::oio::WriteExt; use crate::raw::*; use crate::*; @@ -72,10 +72,17 @@ use crate::*; /// - Services that doesn't support append will return [`ErrorKind::Unsupported`] error when /// creating writer with `append` enabled. pub struct Writer { - inner: oio::Writer, + state: State, } impl Writer { + /// Create a new writer from an `oio::Writer`. + pub(crate) fn new(w: oio::Writer) -> Self { + Writer { + state: State::Idle(Some(w)), + } + } + /// Create a new writer. /// /// Create will use internal information to decide the most suitable @@ -86,14 +93,20 @@ impl Writer { pub(crate) async fn create(acc: FusedAccessor, path: &str, op: OpWrite) -> Result { let (_, w) = acc.write(path, op).await?; - Ok(Writer { inner: w }) + Ok(Writer { + state: State::Idle(Some(w)), + }) } /// Write into inner writer. pub async fn write(&mut self, bs: impl Into) -> Result<()> { + let State::Idle(Some(w)) = &mut self.state else { + return Err(Error::new(ErrorKind::Unexpected, "writer must be valid")); + }; + let mut bs = bs.into(); while !bs.is_empty() { - let n = self.inner.write(bs.clone()).await?; + let n = w.write(bs.clone()).await?; bs.advance(n); } @@ -134,12 +147,16 @@ impl Writer { S: futures::Stream>, T: Into, { - let mut sink_from = Box::pin(sink_from); + let State::Idle(Some(w)) = &mut self.state else { + return Err(Error::new(ErrorKind::Unexpected, "writer must be valid")); + }; + + let mut sink_from = pin!(sink_from); let mut written = 0; while let Some(bs) = sink_from.try_next().await? { let mut bs = bs.into(); while !bs.is_empty() { - let n = self.inner.write(bs.clone()).await?; + let n = w.write(bs.clone()).await?; bs.advance(n); written += n as u64; } @@ -195,7 +212,11 @@ impl Writer { /// Abort should only be called when the writer is not closed or /// aborted, otherwise an unexpected error could be returned. pub async fn abort(&mut self) -> Result<()> { - self.inner.abort().await + let State::Idle(Some(w)) = &mut self.state else { + return Err(Error::new(ErrorKind::Unexpected, "writer must be valid")); + }; + + w.abort().await } /// Close the writer and make sure all data have been committed. @@ -205,19 +226,49 @@ impl Writer { /// Close should only be called when the writer is not closed or /// aborted, otherwise an unexpected error could be returned. pub async fn close(&mut self) -> Result<()> { - self.inner.close().await + let State::Idle(Some(w)) = &mut self.state else { + return Err(Error::new(ErrorKind::Unexpected, "writer must be valid")); + }; + + w.close().await } } -impl AsyncWrite for Writer { +enum State { + Idle(Option), + Writing(BoxedStaticFuture<(oio::Writer, Result)>), + Closing(BoxedStaticFuture<(oio::Writer, Result<()>)>), +} + +unsafe impl Sync for State {} + +impl futures::AsyncWrite for Writer { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.inner - .poll_write(cx, Bytes::copy_from_slice(buf)) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + match &mut self.state { + State::Idle(w) => { + let mut w = w.take().expect("writer must be valid"); + let bs = Bytes::copy_from_slice(buf); + let fut = async move { + let res = w.write(bs).await; + (w, res) + }; + self.state = State::Writing(Box::pin(fut)); + self.poll_write(cx, buf) + } + State::Writing(fut) => { + let (w, res) = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle(Some(w)); + Poll::Ready(res.map_err(format_std_io_error)) + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::Interrupted, + "another io operation is in progress", + ))), + } } /// Writer makes sure that every write is flushed. @@ -226,9 +277,26 @@ impl AsyncWrite for Writer { } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner - .poll_close(cx) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + match &mut self.state { + State::Idle(w) => { + let mut w = w.take().expect("writer must be valid"); + let fut = async move { + let res = w.close().await; + (w, res) + }; + self.state = State::Closing(Box::pin(fut)); + self.poll_close(cx) + } + State::Closing(fut) => { + let (w, res) = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle(Some(w)); + Poll::Ready(res.map_err(format_std_io_error)) + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::Interrupted, + "another io operation is in progress", + ))), + } } } @@ -238,9 +306,27 @@ impl tokio::io::AsyncWrite for Writer { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.inner - .poll_write(cx, Bytes::copy_from_slice(buf)) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + match &mut self.state { + State::Idle(w) => { + let mut w = w.take().expect("writer must be valid"); + let bs = Bytes::copy_from_slice(buf); + let fut = async move { + let res = w.write(bs).await; + (w, res) + }; + self.state = State::Writing(Box::pin(fut)); + self.poll_write(cx, buf) + } + State::Writing(fut) => { + let (w, res) = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle(Some(w)); + Poll::Ready(res.map_err(format_std_io_error)) + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::Interrupted, + "another io operation is in progress", + ))), + } } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { @@ -248,9 +334,26 @@ impl tokio::io::AsyncWrite for Writer { } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner - .poll_close(cx) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) + match &mut self.state { + State::Idle(w) => { + let mut w = w.take().expect("writer must be valid"); + let fut = async move { + let res = w.close().await; + (w, res) + }; + self.state = State::Closing(Box::pin(fut)); + self.poll_shutdown(cx) + } + State::Closing(fut) => { + let (w, res) = ready!(fut.as_mut().poll(cx)); + self.state = State::Idle(Some(w)); + Poll::Ready(res.map_err(format_std_io_error)) + } + _ => Poll::Ready(Err(io::Error::new( + io::ErrorKind::Interrupted, + "another io operation is in progress", + ))), + } } }