From 3b8121eaa9e9628536093836dcc41119716afd9e Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Tue, 14 May 2024 22:31:42 +0800 Subject: [PATCH] feat: Extract FileRead and FileWrite trait (#364) * feat: Extract FileRead and FileWrie trait Signed-off-by: Xuanwo * Enable s3 services for tests Signed-off-by: Xuanwo * Fix sort Signed-off-by: Xuanwo * Add comment for io trait Signed-off-by: Xuanwo * Fix test for rest Signed-off-by: Xuanwo * Use try join Signed-off-by: Xuanwo --------- Signed-off-by: Xuanwo --- Cargo.toml | 12 +- crates/catalog/glue/Cargo.toml | 1 + crates/catalog/glue/src/catalog.rs | 15 +-- crates/catalog/hms/Cargo.toml | 1 + crates/catalog/hms/src/catalog.rs | 15 +-- crates/catalog/rest/Cargo.toml | 1 + crates/iceberg/src/arrow/reader.rs | 59 ++++++++- crates/iceberg/src/io.rs | 119 +++++++++++++++--- crates/iceberg/src/spec/manifest.rs | 10 +- crates/iceberg/src/spec/manifest_list.rs | 12 +- crates/iceberg/src/spec/snapshot.rs | 9 +- crates/iceberg/src/table.rs | 20 +-- .../src/writer/file_writer/parquet_writer.rs | 110 +++++++++++++++- .../src/writer/file_writer/track_writer.rs | 47 +++---- crates/iceberg/src/writer/mod.rs | 17 +-- crates/iceberg/tests/file_io_s3_test.rs | 15 +-- crates/integrations/datafusion/Cargo.toml | 1 + 17 files changed, 315 insertions(+), 149 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d1894c146..57c343611 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,11 @@ [workspace] resolver = "2" members = [ - "crates/catalog/*", - "crates/examples", - "crates/iceberg", - "crates/integrations/*", - "crates/test_utils", + "crates/catalog/*", + "crates/examples", + "crates/iceberg", + "crates/integrations/*", + "crates/test_utils", ] [workspace.package] @@ -64,7 +64,7 @@ log = "^0.4" mockito = "^1" murmur3 = "0.5.2" once_cell = "1" -opendal = "0.45" +opendal = "0.46" ordered-float = "4.0.0" parquet = "51" pilota = "0.11.0" diff --git a/crates/catalog/glue/Cargo.toml b/crates/catalog/glue/Cargo.toml index 0508378e7..8e1c077f1 100644 --- a/crates/catalog/glue/Cargo.toml +++ b/crates/catalog/glue/Cargo.toml @@ -42,4 +42,5 @@ uuid = { workspace = true } [dev-dependencies] iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } +opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true } diff --git a/crates/catalog/glue/src/catalog.rs b/crates/catalog/glue/src/catalog.rs index f40212950..147d86ac9 100644 --- a/crates/catalog/glue/src/catalog.rs +++ b/crates/catalog/glue/src/catalog.rs @@ -25,7 +25,6 @@ use iceberg::{ TableIdent, }; use std::{collections::HashMap, fmt::Debug}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use typed_builder::TypedBuilder; @@ -358,13 +357,10 @@ impl Catalog for GlueCatalog { let metadata = TableMetadataBuilder::from_table_creation(creation)?.build()?; let metadata_location = create_metadata_location(&location, 0)?; - let mut file = self - .file_io + self.file_io .new_output(&metadata_location)? 
- .writer() + .write(serde_json::to_vec(&metadata)?.into()) .await?; - file.write_all(&serde_json::to_vec(&metadata)?).await?; - file.shutdown().await?; let glue_table = convert_to_glue_table( &table_name, @@ -431,10 +427,9 @@ impl Catalog for GlueCatalog { Some(table) => { let metadata_location = get_metadata_location(&table.parameters)?; - let mut reader = self.file_io.new_input(&metadata_location)?.reader().await?; - let mut metadata_str = String::new(); - reader.read_to_string(&mut metadata_str).await?; - let metadata = serde_json::from_str::(&metadata_str)?; + let input_file = self.file_io.new_input(&metadata_location)?; + let metadata_content = input_file.read().await?; + let metadata = serde_json::from_slice::(&metadata_content)?; let table = Table::builder() .file_io(self.file_io()) diff --git a/crates/catalog/hms/Cargo.toml b/crates/catalog/hms/Cargo.toml index 5a032215a..b53901552 100644 --- a/crates/catalog/hms/Cargo.toml +++ b/crates/catalog/hms/Cargo.toml @@ -44,4 +44,5 @@ volo-thrift = { workspace = true } [dev-dependencies] iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } +opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true } diff --git a/crates/catalog/hms/src/catalog.rs b/crates/catalog/hms/src/catalog.rs index 2f545dd03..18fcacdfc 100644 --- a/crates/catalog/hms/src/catalog.rs +++ b/crates/catalog/hms/src/catalog.rs @@ -35,8 +35,6 @@ use iceberg::{ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::net::ToSocketAddrs; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; use typed_builder::TypedBuilder; use volo_thrift::ResponseError; @@ -349,13 +347,10 @@ impl Catalog for HmsCatalog { let metadata = TableMetadataBuilder::from_table_creation(creation)?.build()?; let metadata_location = create_metadata_location(&location, 0)?; - let mut file = self - .file_io + self.file_io .new_output(&metadata_location)? 
- .writer() + .write(serde_json::to_vec(&metadata)?.into()) .await?; - file.write_all(&serde_json::to_vec(&metadata)?).await?; - file.shutdown().await?; let hive_table = convert_to_hive_table( db_name.clone(), @@ -406,10 +401,8 @@ impl Catalog for HmsCatalog { let metadata_location = get_metadata_location(&hive_table.parameters)?; - let mut reader = self.file_io.new_input(&metadata_location)?.reader().await?; - let mut metadata_str = String::new(); - reader.read_to_string(&mut metadata_str).await?; - let metadata = serde_json::from_str::(&metadata_str)?; + let metadata_content = self.file_io.new_input(&metadata_location)?.read().await?; + let metadata = serde_json::from_slice::(&metadata_content)?; let table = Table::builder() .file_io(self.file_io()) diff --git a/crates/catalog/rest/Cargo.toml b/crates/catalog/rest/Cargo.toml index 7abe9c8e3..43e589910 100644 --- a/crates/catalog/rest/Cargo.toml +++ b/crates/catalog/rest/Cargo.toml @@ -46,5 +46,6 @@ uuid = { workspace = true, features = ["v4"] } [dev-dependencies] iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } mockito = { workspace = true } +opendal = { workspace = true, features = ["services-fs"] } port_scanner = { workspace = true } tokio = { workspace = true } diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs index e3f30f8d9..fe5efaca1 100644 --- a/crates/iceberg/src/arrow/reader.rs +++ b/crates/iceberg/src/arrow/reader.rs @@ -19,14 +19,21 @@ use arrow_schema::SchemaRef as ArrowSchemaRef; use async_stream::try_stream; +use bytes::Bytes; +use futures::future::BoxFuture; use futures::stream::StreamExt; +use futures::{try_join, TryFutureExt}; +use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader}; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY}; +use parquet::file::metadata::ParquetMetaData; use parquet::schema::types::SchemaDescriptor; use std::collections::HashMap; +use std::ops::Range; use std::str::FromStr; +use std::sync::Arc; use crate::arrow::arrow_schema_to_schema; -use crate::io::FileIO; +use crate::io::{FileIO, FileMetadata, FileRead}; use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream}; use crate::spec::SchemaRef; use crate::{Error, ErrorKind}; @@ -91,12 +98,12 @@ impl ArrowReader { Ok(try_stream! { while let Some(Ok(task)) = tasks.next().await { - let parquet_reader = file_io - .new_input(task.data().data_file().file_path())? - .reader() - .await?; + let parquet_file = file_io + .new_input(task.data().data_file().file_path())?; + let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?; + let arrow_file_reader = ArrowFileReader::new(parquet_metadata, parquet_reader); - let mut batch_stream_builder = ParquetRecordBatchStreamBuilder::new(parquet_reader) + let mut batch_stream_builder = ParquetRecordBatchStreamBuilder::new(arrow_file_reader) .await?; let parquet_schema = batch_stream_builder.parquet_schema(); @@ -187,3 +194,43 @@ impl ArrowReader { } } } + +/// ArrowFileReader is a wrapper around a FileRead that impls parquets AsyncFileReader. +/// +/// # TODO +/// +/// [ParquetObjectReader](https://docs.rs/parquet/latest/src/parquet/arrow/async_reader/store.rs.html#64) contains the following hints to speed up metadata loading, we can consider adding them to this struct: +/// +/// - `metadata_size_hint`: Provide a hint as to the size of the parquet file's footer. +/// - `preload_column_index`: Load the Column Index as part of [`Self::get_metadata`]. 
+/// - `preload_offset_index`: Load the Offset Index as part of [`Self::get_metadata`]. +struct ArrowFileReader { + meta: FileMetadata, + r: R, +} + +impl ArrowFileReader { + /// Create a new ArrowFileReader + fn new(meta: FileMetadata, r: R) -> Self { + Self { meta, r } + } +} + +impl AsyncFileReader for ArrowFileReader { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, parquet::errors::Result> { + Box::pin( + self.r + .read(range.start as _..range.end as _) + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))), + ) + } + + fn get_metadata(&mut self) -> BoxFuture<'_, parquet::errors::Result>> { + Box::pin(async move { + let file_size = self.meta.size; + let mut loader = MetadataLoader::load(self, file_size as usize, None).await?; + loader.load_page_index(false, false).await?; + Ok(Arc::new(loader.finish())) + }) + } +} diff --git a/crates/iceberg/src/io.rs b/crates/iceberg/src/io.rs index d3f07cb64..c045b22f1 100644 --- a/crates/iceberg/src/io.rs +++ b/crates/iceberg/src/io.rs @@ -48,14 +48,13 @@ //! - `new_input`: Create input file for reading. //! - `new_output`: Create output file for writing. +use bytes::Bytes; +use std::ops::Range; use std::{collections::HashMap, sync::Arc}; use crate::{error::Result, Error, ErrorKind}; -use futures::{AsyncRead, AsyncSeek, AsyncWrite}; use once_cell::sync::Lazy; use opendal::{Operator, Scheme}; -use tokio::io::AsyncWrite as TokioAsyncWrite; -use tokio::io::{AsyncRead as TokioAsyncRead, AsyncSeek as TokioAsyncSeek}; use url::Url; /// Following are arguments for [s3 file io](https://py.iceberg.apache.org/configuration/#s3). @@ -206,6 +205,35 @@ impl FileIO { } } +/// The struct the represents the metadata of a file. +/// +/// TODO: we can add last modified time, content type, etc. in the future. +pub struct FileMetadata { + /// The size of the file. + pub size: u64, +} + +/// Trait for reading file. +/// +/// # TODO +/// +/// It's possible for us to remove the async_trait, but we need to figure +/// out how to handle the object safety. +#[async_trait::async_trait] +pub trait FileRead: Send + Unpin + 'static { + /// Read file content with given range. + /// + /// TODO: we can support reading non-contiguous bytes in the future. + async fn read(&self, range: Range) -> Result; +} + +#[async_trait::async_trait] +impl FileRead for opendal::Reader { + async fn read(&self, range: Range) -> Result { + Ok(opendal::Reader::read(self, range).await?.to_bytes()) + } +} + /// Input file is used for reading from files. #[derive(Debug)] pub struct InputFile { @@ -216,14 +244,6 @@ pub struct InputFile { relative_path_pos: usize, } -/// Trait for reading file. -pub trait FileRead: AsyncRead + AsyncSeek + Send + Unpin + TokioAsyncRead + TokioAsyncSeek {} - -impl FileRead for T where - T: AsyncRead + AsyncSeek + Send + Unpin + TokioAsyncRead + TokioAsyncSeek -{ -} - impl InputFile { /// Absolute path to root uri. pub fn location(&self) -> &str { @@ -238,16 +258,63 @@ impl InputFile { .await?) } - /// Creates [`InputStream`] for reading. + /// Fetch and returns metadata of file. + pub async fn metadata(&self) -> Result { + let meta = self.op.stat(&self.path[self.relative_path_pos..]).await?; + + Ok(FileMetadata { + size: meta.content_length(), + }) + } + + /// Read and returns whole content of file. + /// + /// For continues reading, use [`Self::reader`] instead. + pub async fn read(&self) -> Result { + Ok(self + .op + .read(&self.path[self.relative_path_pos..]) + .await? + .to_bytes()) + } + + /// Creates [`FileRead`] for continues reading. 
+ /// + /// For one-time reading, use [`Self::read`] instead. pub async fn reader(&self) -> Result { Ok(self.op.reader(&self.path[self.relative_path_pos..]).await?) } } /// Trait for writing file. -pub trait FileWrite: AsyncWrite + TokioAsyncWrite + Send + Unpin {} +/// +/// # TODO +/// +/// It's possible for us to remove the async_trait, but we need to figure +/// out how to handle the object safety. +#[async_trait::async_trait] +pub trait FileWrite: Send + Unpin + 'static { + /// Write bytes to file. + /// + /// TODO: we can support writing non-contiguous bytes in the future. + async fn write(&mut self, bs: Bytes) -> Result<()>; -impl FileWrite for T where T: AsyncWrite + TokioAsyncWrite + Send + Unpin {} + /// Close file. + /// + /// Calling close on closed file will generate an error. + async fn close(&mut self) -> Result<()>; +} + +#[async_trait::async_trait] +impl FileWrite for opendal::Writer { + async fn write(&mut self, bs: Bytes) -> Result<()> { + Ok(opendal::Writer::write(self, bs).await?) + } + + async fn close(&mut self) -> Result<()> { + Ok(opendal::Writer::close(self).await?) + } +} /// Output file is used for writing to files.. #[derive(Debug)] @@ -282,7 +349,23 @@ impl OutputFile { } } - /// Creates output file for writing. + /// Create a new output file with given bytes. + /// + /// # Notes + /// + /// Calling `write` will overwrite the file if it exists. + /// For continues writing, use [`Self::writer`]. + pub async fn write(&self, bs: Bytes) -> Result<()> { + let mut writer = self.writer().await?; + writer.write(bs).await?; + writer.close().await + } + + /// Creates output file for continues writing. + /// + /// # Notes + /// + /// For one-time writing, use [`Self::write`] instead. pub async fn writer(&self) -> Result> { Ok(Box::new( self.op.writer(&self.path[self.relative_path_pos..]).await?, @@ -398,7 +481,7 @@ mod tests { use std::{fs::File, path::Path}; use futures::io::AllowStdIo; - use futures::{AsyncReadExt, AsyncWriteExt}; + use futures::AsyncReadExt; use tempfile::TempDir; @@ -483,9 +566,7 @@ mod tests { assert!(!output_file.exists().await.unwrap()); { - let mut writer = output_file.writer().await.unwrap(); - writer.write_all(content.as_bytes()).await.unwrap(); - writer.close().await.unwrap(); + output_file.write(content.into()).await.unwrap(); } assert_eq!(&full_path, output_file.location()); diff --git a/crates/iceberg/src/spec/manifest.rs b/crates/iceberg/src/spec/manifest.rs index 3daa5c288..b1eb21653 100644 --- a/crates/iceberg/src/spec/manifest.rs +++ b/crates/iceberg/src/spec/manifest.rs @@ -28,7 +28,7 @@ use crate::io::OutputFile; use crate::spec::PartitionField; use crate::{Error, ErrorKind}; use apache_avro::{from_value, to_value, Reader as AvroReader, Writer as AvroWriter}; -use futures::AsyncWriteExt; +use bytes::Bytes; use serde_json::to_vec; use std::cmp::min; use std::collections::HashMap; @@ -291,13 +291,7 @@ impl ManifestWriter { let length = avro_writer.flush()?; let content = avro_writer.into_inner()?; - let mut writer = self.output.writer().await?; - writer.write_all(&content).await.map_err(|err| { - Error::new(ErrorKind::Unexpected, "Fail to write Manifest Entry").with_source(err) - })?; - writer.close().await.map_err(|err| { - Error::new(ErrorKind::Unexpected, "Fail to write Manifest Entry").with_source(err) - })?; + self.output.write(Bytes::from(content)).await?; let partition_summary = self.get_field_summary_vec(&manifest.metadata.partition_spec.fields); diff --git a/crates/iceberg/src/spec/manifest_list.rs 
b/crates/iceberg/src/spec/manifest_list.rs index c390bee04..26a4acc60 100644 --- a/crates/iceberg/src/spec/manifest_list.rs +++ b/crates/iceberg/src/spec/manifest_list.rs @@ -22,7 +22,7 @@ use std::{collections::HashMap, str::FromStr}; use crate::io::FileIO; use crate::{io::OutputFile, spec::Literal, Error, ErrorKind}; use apache_avro::{from_value, types::Value, Reader, Writer}; -use futures::{AsyncReadExt, AsyncWriteExt}; +use bytes::Bytes; use self::{ _const_schema::{MANIFEST_LIST_AVRO_SCHEMA_V1, MANIFEST_LIST_AVRO_SCHEMA_V2}, @@ -212,7 +212,7 @@ impl ManifestListWriter { pub async fn close(self) -> Result<()> { let data = self.avro_writer.into_inner()?; let mut writer = self.output_file.writer().await?; - writer.write_all(&data).await?; + writer.write(Bytes::from(data)).await?; writer.close().await?; Ok(()) } @@ -632,13 +632,7 @@ impl ManifestFile { /// /// This method will also initialize inherited values of [`ManifestEntry`], such as `sequence_number`. pub async fn load_manifest(&self, file_io: &FileIO) -> Result { - let mut avro = Vec::new(); - file_io - .new_input(&self.manifest_path)? - .reader() - .await? - .read_to_end(&mut avro) - .await?; + let avro = file_io.new_input(&self.manifest_path)?.read().await?; let (metadata, mut entries) = Manifest::try_from_avro_bytes(&avro)?; diff --git a/crates/iceberg/src/spec/snapshot.rs b/crates/iceberg/src/spec/snapshot.rs index 3b4558bb6..53eee6bf6 100644 --- a/crates/iceberg/src/spec/snapshot.rs +++ b/crates/iceberg/src/spec/snapshot.rs @@ -20,7 +20,6 @@ */ use crate::error::Result; use chrono::{DateTime, TimeZone, Utc}; -use futures::AsyncReadExt; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -166,13 +165,7 @@ impl Snapshot { file_io: &FileIO, table_metadata: &TableMetadata, ) -> Result { - let mut manifest_list_content = Vec::new(); - file_io - .new_input(&self.manifest_list)? - .reader() - .await? - .read_to_end(&mut manifest_list_content) - .await?; + let manifest_list_content = file_io.new_input(&self.manifest_list)?.read().await?; let schema = self.schema(table_metadata)?; diff --git a/crates/iceberg/src/table.rs b/crates/iceberg/src/table.rs index f38d77131..fd8bd28f2 100644 --- a/crates/iceberg/src/table.rs +++ b/crates/iceberg/src/table.rs @@ -21,7 +21,6 @@ use crate::scan::TableScanBuilder; use crate::spec::{TableMetadata, TableMetadataRef}; use crate::Result; use crate::TableIdent; -use futures::AsyncReadExt; use typed_builder::TypedBuilder; /// Table represents a table in the catalog. 
@@ -118,12 +117,8 @@ impl StaticTable { file_io: FileIO, ) -> Result { let metadata_file = file_io.new_input(metadata_file_path)?; - let mut metadata_file_reader = metadata_file.reader().await?; - let mut metadata_file_content = String::new(); - metadata_file_reader - .read_to_string(&mut metadata_file_content) - .await?; - let table_metadata = serde_json::from_str::(&metadata_file_content)?; + let metadata_file_content = metadata_file.read().await?; + let table_metadata = serde_json::from_slice::(&metadata_file_content)?; Self::from_metadata(table_metadata, table_ident, file_io).await } @@ -148,6 +143,7 @@ impl StaticTable { #[cfg(test)] mod tests { use super::*; + #[tokio::test] async fn test_static_table_from_file() { let metadata_file_name = "TableMetadataV2Valid.json"; @@ -211,13 +207,9 @@ mod tests { .build() .unwrap(); let metadata_file = file_io.new_input(metadata_file_path).unwrap(); - let mut metadata_file_reader = metadata_file.reader().await.unwrap(); - let mut metadata_file_content = String::new(); - metadata_file_reader - .read_to_string(&mut metadata_file_content) - .await - .unwrap(); - let table_metadata = serde_json::from_str::(&metadata_file_content).unwrap(); + let metadata_file_content = metadata_file.read().await.unwrap(); + let table_metadata = + serde_json::from_slice::(&metadata_file_content).unwrap(); let static_identifier = TableIdent::from_strs(["ns", "table"]).unwrap(); let table = Table::builder() .metadata(table_metadata) diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs index b743d8435..a67d308af 100644 --- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs @@ -17,12 +17,14 @@ //! The module contains the file writer for parquet file format. +use std::pin::Pin; +use std::task::{Context, Poll}; use std::{ collections::HashMap, sync::{atomic::AtomicI64, Arc}, }; -use crate::{io::FileIO, Result}; +use crate::{io::FileIO, io::FileWrite, Result}; use crate::{ io::OutputFile, spec::{DataFileBuilder, DataFileFormat}, @@ -30,6 +32,8 @@ use crate::{ Error, }; use arrow_schema::SchemaRef as ArrowSchemaRef; +use bytes::Bytes; +use futures::future::BoxFuture; use parquet::{arrow::AsyncArrowWriter, format::FileMetaData}; use parquet::{arrow::PARQUET_FIELD_ID_META_KEY, file::properties::WriterProperties}; @@ -103,7 +107,8 @@ impl FileWriterBuilder for ParquetWr .generate_location(&self.file_name_generator.generate_file_name()), )?; let inner_writer = TrackWriter::new(out_file.writer().await?, written_size.clone()); - let writer = AsyncArrowWriter::try_new(inner_writer, self.schema.clone(), Some(self.props)) + let async_writer = AsyncFileWriter::new(inner_writer); + let writer = AsyncArrowWriter::try_new(async_writer, self.schema.clone(), Some(self.props)) .map_err(|err| { Error::new( crate::ErrorKind::Unexpected, @@ -125,7 +130,7 @@ impl FileWriterBuilder for ParquetWr /// `ParquetWriter`` is used to write arrow data into parquet file on storage. pub struct ParquetWriter { out_file: OutputFile, - writer: AsyncArrowWriter, + writer: AsyncArrowWriter>, written_size: Arc, current_row_num: usize, field_ids: Vec, @@ -246,6 +251,105 @@ impl CurrentFileStatus for ParquetWriter { } } +/// AsyncFileWriter is a wrapper of FileWrite to make it compatible with tokio::io::AsyncWrite. +/// +/// # NOTES +/// +/// We keep this wrapper been used inside only. +/// +/// # TODO +/// +/// Maybe we can use the buffer from ArrowWriter directly. 
+struct AsyncFileWriter(State); + +enum State { + Idle(Option), + Write(BoxFuture<'static, (W, Result<()>)>), + Close(BoxFuture<'static, (W, Result<()>)>), +} + +impl AsyncFileWriter { + /// Create a new `AsyncFileWriter` with the given writer. + pub fn new(writer: W) -> Self { + Self(State::Idle(Some(writer))) + } +} + +impl tokio::io::AsyncWrite for AsyncFileWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + loop { + match &mut this.0 { + State::Idle(w) => { + let mut writer = w.take().unwrap(); + let bs = Bytes::copy_from_slice(buf); + let fut = async move { + let res = writer.write(bs).await; + (writer, res) + }; + this.0 = State::Write(Box::pin(fut)); + } + State::Write(fut) => { + let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); + this.0 = State::Idle(Some(writer)); + return Poll::Ready(res.map(|_| buf.len()).map_err(|err| { + std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) + })); + } + State::Close(_) => { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "file is closed", + ))); + } + } + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + loop { + match &mut this.0 { + State::Idle(w) => { + let mut writer = w.take().unwrap(); + let fut = async move { + let res = writer.close().await; + (writer, res) + }; + this.0 = State::Close(Box::pin(fut)); + } + State::Write(_) => { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "file is writing", + ))); + } + State::Close(fut) => { + let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); + this.0 = State::Idle(Some(writer)); + return Poll::Ready(res.map_err(|err| { + std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) + })); + } + } + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/crates/iceberg/src/writer/file_writer/track_writer.rs b/crates/iceberg/src/writer/file_writer/track_writer.rs index 938addd08..8d0e490d4 100644 --- a/crates/iceberg/src/writer/file_writer/track_writer.rs +++ b/crates/iceberg/src/writer/file_writer/track_writer.rs @@ -15,14 +15,11 @@ // specific language governing permissions and limitations // under the License. -use std::{ - pin::Pin, - sync::{atomic::AtomicI64, Arc}, -}; - -use tokio::io::AsyncWrite; +use bytes::Bytes; +use std::sync::{atomic::AtomicI64, Arc}; use crate::io::FileWrite; +use crate::Result; /// `TrackWriter` is used to track the written size. 
pub(crate) struct TrackWriter { @@ -39,34 +36,18 @@ impl TrackWriter { } } -impl AsyncWrite for TrackWriter { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - match Pin::new(&mut self.inner).poll_write(cx, buf) { - std::task::Poll::Ready(Ok(n)) => { - self.written_size - .fetch_add(buf.len() as i64, std::sync::atomic::Ordering::Relaxed); - std::task::Poll::Ready(Ok(n)) - } - std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(e)), - std::task::Poll::Pending => std::task::Poll::Pending, - } - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Pin::new(&mut self.inner).poll_flush(cx) +#[async_trait::async_trait] +impl FileWrite for TrackWriter { + async fn write(&mut self, bs: Bytes) -> Result<()> { + let size = bs.len(); + self.inner.write(bs).await.map(|v| { + self.written_size + .fetch_add(size as i64, std::sync::atomic::Ordering::Relaxed); + v + }) } - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) + async fn close(&mut self) -> Result<()> { + self.inner.close().await } } diff --git a/crates/iceberg/src/writer/mod.rs b/crates/iceberg/src/writer/mod.rs index 7618d2ec3..216e94f9e 100644 --- a/crates/iceberg/src/writer/mod.rs +++ b/crates/iceberg/src/writer/mod.rs @@ -95,8 +95,6 @@ mod tests { use arrow_array::RecordBatch; use arrow_schema::Schema; use arrow_select::concat::concat_batches; - use bytes::Bytes; - use futures::AsyncReadExt; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use crate::{ @@ -124,16 +122,11 @@ mod tests { ) { assert_eq!(data_file.file_format, DataFileFormat::Parquet); + let input_file = file_io.new_input(data_file.file_path.clone()).unwrap(); // read the written file - let mut input_file = file_io - .new_input(data_file.file_path.clone()) - .unwrap() - .reader() - .await - .unwrap(); - let mut res = vec![]; - let file_size = input_file.read_to_end(&mut res).await.unwrap(); - let reader_builder = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(res)).unwrap(); + let input_content = input_file.read().await.unwrap(); + let reader_builder = + ParquetRecordBatchReaderBuilder::try_new(input_content.clone()).unwrap(); let metadata = reader_builder.metadata().clone(); // check data @@ -154,7 +147,7 @@ mod tests { .sum::() as u64 ); - assert_eq!(data_file.file_size_in_bytes, file_size as u64); + assert_eq!(data_file.file_size_in_bytes, input_content.len() as u64); assert_eq!(data_file.column_sizes.len(), expect_column_num); data_file.column_sizes.iter().for_each(|(&k, &v)| { diff --git a/crates/iceberg/tests/file_io_s3_test.rs b/crates/iceberg/tests/file_io_s3_test.rs index 7553bcd2e..36e24f153 100644 --- a/crates/iceberg/tests/file_io_s3_test.rs +++ b/crates/iceberg/tests/file_io_s3_test.rs @@ -17,7 +17,6 @@ //! Integration tests for FileIO S3. 
-use futures::{AsyncReadExt, AsyncWriteExt}; use iceberg::io::{ FileIO, FileIOBuilder, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY, }; @@ -74,9 +73,7 @@ async fn test_file_io_s3_output() { .new_output("s3://bucket1/test_output") .unwrap(); { - let mut writer = output_file.writer().await.unwrap(); - writer.write_all("123".as_bytes()).await.unwrap(); - writer.close().await.unwrap(); + output_file.write("123".into()).await.unwrap(); } assert!(fixture .file_io @@ -93,18 +90,16 @@ async fn test_file_io_s3_input() { .new_output("s3://bucket1/test_input") .unwrap(); { - let mut writer = output_file.writer().await.unwrap(); - writer.write_all("test_input".as_bytes()).await.unwrap(); - writer.close().await.unwrap(); + output_file.write("test_input".into()).await.unwrap(); } + let input_file = fixture .file_io .new_input("s3://bucket1/test_input") .unwrap(); + { - let mut reader = input_file.reader().await.unwrap(); - let mut buffer = vec![]; - reader.read_to_end(&mut buffer).await.unwrap(); + let buffer = input_file.read().await.unwrap(); assert_eq!(buffer, "test_input".as_bytes()); } } diff --git a/crates/integrations/datafusion/Cargo.toml b/crates/integrations/datafusion/Cargo.toml index 9f895ab33..4e01723e1 100644 --- a/crates/integrations/datafusion/Cargo.toml +++ b/crates/integrations/datafusion/Cargo.toml @@ -40,4 +40,5 @@ tokio = { workspace = true } [dev-dependencies] iceberg-catalog-hms = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } +opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true }
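
The hunks above only show how the new traits are wired into the crate's internal call sites. As a minimal usage sketch of the API this patch introduces — assuming the existing FileIOBuilder::new_fs_io() constructor and placeholder /tmp/iceberg-demo/... paths; everything else mirrors the signatures added in crates/iceberg/src/io.rs:

use bytes::Bytes;
use iceberg::io::{FileIOBuilder, FileRead, FileWrite};
use iceberg::Result;

async fn file_io_roundtrip() -> Result<()> {
    // Build a local-filesystem FileIO; an S3-backed FileIO behaves the same way.
    let file_io = FileIOBuilder::new_fs_io().build()?;

    // One-shot write: OutputFile::write opens, writes and closes in one call,
    // overwriting the file if it already exists.
    let output = file_io.new_output("/tmp/iceberg-demo/metadata.json")?;
    output
        .write(Bytes::from_static(b"{\"format-version\": 2}"))
        .await?;

    // Streaming write: FileWrite::write may be called repeatedly before close.
    let output = file_io.new_output("/tmp/iceberg-demo/data.bin")?;
    let mut writer = output.writer().await?;
    writer.write(Bytes::from_static(b"part-1")).await?;
    writer.write(Bytes::from_static(b"part-2")).await?;
    writer.close().await?;

    // One-shot read: InputFile::read returns the whole file as Bytes.
    let input = file_io.new_input("/tmp/iceberg-demo/data.bin")?;
    let all = input.read().await?;
    assert_eq!(&all[..], b"part-1part-2");

    // Ranged read: FileRead::read takes a byte range, which is what the
    // ArrowFileReader in arrow/reader.rs relies on for parquet footers.
    let size = input.metadata().await?.size;
    let reader = input.reader().await?;
    let tail = FileRead::read(&reader, size - 6..size).await?;
    assert_eq!(&tail[..], b"part-2");

    Ok(())
}

Call sites that previously pulled in tokio's AsyncReadExt/AsyncWriteExt (the glue and hms catalogs, manifest list, snapshot and static-table loading above) now go through these byte-oriented helpers instead of stream adapters.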
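
The TODO notes on FileRead and FileWrite explain why both traits stay behind async_trait for now (object safety, e.g. for the boxed writer returned by OutputFile::writer). For backends other than the opendal::Reader / opendal::Writer impls added here, the contract is small; a hedged sketch with a hypothetical in-memory type, not part of this patch:

use std::ops::Range;

use bytes::Bytes;
use iceberg::io::FileRead;
use iceberg::Result;

// Hypothetical backend used only to illustrate the FileRead contract.
struct InMemoryFile {
    data: Bytes,
}

#[async_trait::async_trait]
impl FileRead for InMemoryFile {
    async fn read(&self, range: Range<u64>) -> Result<Bytes> {
        // Return exactly the requested byte range; ArrowFileReader::get_bytes
        // depends on ranged reads like this being cheap.
        Ok(self.data.slice(range.start as usize..range.end as usize))
    }
}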