From b6eee0b5169541b1f5e61f1fc45e76d609360597 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 26 Nov 2024 22:10:14 -0800 Subject: [PATCH] =?UTF-8?q?[FEAT]=20migrate=20schema=20inference=20?= =?UTF-8?q?=E2=86=92=20async,=20block=20at=20py=20boundary=20(#3432)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Converts schema inference operations for CSV, JSON, and Parquet files to use async/await instead of synchronous runtime blocking. This architectural change ensures that blocking operations happen at the highest level possible (the Python API boundary) rather than deep within the inference logic. Key changes include: - Making read_csv_schema, read_json_schema, and read_parquet_schema async - Updating scan builder interfaces to use async finish() methods - Removing unnecessary runtime.block_on calls from schema inference paths - Moving runtime.block_on calls to Python API layer where blocking is unavoidable - Converting schema-related tests to use tokio async runtime - Adding common-runtime dependency where needed - Fixes #3423 This change improves the consistency of async IO handling and creates a cleaner architecture where blocking is consolidated at the Python interface rather than scattered throughout the codebase. --- Cargo.lock | 2 + Cargo.toml | 2 +- src/daft-csv/src/metadata.rs | 145 +++++++++--------- src/daft-csv/src/python.rs | 21 ++- src/daft-json/src/python.rs | 21 ++- src/daft-json/src/schema.rs | 56 +++---- src/daft-parquet/src/python.rs | 16 +- src/daft-parquet/src/read.rs | 8 +- src/daft-scan/Cargo.toml | 3 + src/daft-scan/src/builder.rs | 54 ++++--- src/daft-scan/src/glob.rs | 61 ++++---- src/daft-scan/src/lib.rs | 9 +- src/daft-scan/src/python.rs | 10 +- src/daft-sql/Cargo.toml | 1 + src/daft-sql/src/table_provider/read_csv.rs | 4 +- .../src/table_provider/read_parquet.rs | 5 +- 16 files changed, 231 insertions(+), 187 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index deb83e187b..fd22dcfa10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2415,6 +2415,7 @@ dependencies = [ "pyo3", "serde", "snafu", + "tokio", "typetag", "urlencoding", ] @@ -2479,6 +2480,7 @@ dependencies = [ "common-daft-config", "common-error", "common-io-config", + "common-runtime", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index bf8ef7bdea..67334d8b0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -194,11 +194,11 @@ chrono-tz = "0.10.0" comfy-table = "7.1.1" common-daft-config = {path = "src/common/daft-config"} common-error = {path = "src/common/error", default-features = false} +common-runtime = {path = "src/common/runtime", default-features = false} daft-core = {path = "src/daft-core"} daft-dsl = {path = "src/daft-dsl"} daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} -daft-local-plan = {path = "src/daft-local-plan"} daft-logical-plan = {path = "src/daft-logical-plan"} daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs index 17ce4c0267..b04595ecb8 100644 --- a/src/daft-csv/src/metadata.rs +++ b/src/daft-csv/src/metadata.rs @@ -3,7 +3,6 @@ use std::{collections::HashSet, sync::Arc}; use arrow2::io::csv::read_async::{AsyncReader, AsyncReaderBuilder}; use async_compat::CompatExt; use common_error::DaftResult; -use common_runtime::get_io_runtime; use csv_async::ByteRecord; use daft_compression::CompressionCodec; use daft_core::prelude::Schema; @@ -52,25 +51,22 @@ impl 
Default for CsvReadStats { } } -pub fn read_csv_schema( +pub async fn read_csv_schema( uri: &str, parse_options: Option, max_bytes: Option, io_client: Arc, io_stats: Option, ) -> DaftResult<(Schema, CsvReadStats)> { - let runtime_handle = get_io_runtime(true); - runtime_handle.block_on_current_thread(async { - read_csv_schema_single( - uri, - parse_options.unwrap_or_default(), - // Default to 1 MiB. - max_bytes.or(Some(1024 * 1024)), - io_client, - io_stats, - ) - .await - }) + read_csv_schema_single( + uri, + parse_options.unwrap_or_default(), + // Default to 1 MiB. + max_bytes.or(Some(1024 * 1024)), + io_client, + io_stats, + ) + .await } pub async fn read_csv_schema_bulk( @@ -81,32 +77,32 @@ pub async fn read_csv_schema_bulk( io_stats: Option, num_parallel_tasks: usize, ) -> DaftResult> { - let runtime_handle = get_io_runtime(true); - let result = runtime_handle - .block_on_current_thread(async { - let task_stream = futures::stream::iter(uris.iter().map(|uri| { - let owned_string = (*uri).to_string(); - let owned_client = io_client.clone(); - let owned_io_stats = io_stats.clone(); - let owned_parse_options = parse_options.clone(); - tokio::spawn(async move { - read_csv_schema_single( - &owned_string, - owned_parse_options.unwrap_or_default(), - max_bytes, - owned_client, - owned_io_stats, - ) - .await - }) - })); - task_stream - .buffered(num_parallel_tasks) - .try_collect::>() + let result = async { + let task_stream = futures::stream::iter(uris.iter().map(|uri| { + let owned_string = (*uri).to_string(); + let owned_client = io_client.clone(); + let owned_io_stats = io_stats.clone(); + let owned_parse_options = parse_options.clone(); + tokio::spawn(async move { + read_csv_schema_single( + &owned_string, + owned_parse_options.unwrap_or_default(), + max_bytes, + owned_client, + owned_io_stats, + ) .await - }) - .context(super::JoinSnafu {})?; - result.into_iter().collect::>>() + }) + })); + task_stream + .buffered(num_parallel_tasks) + .try_collect::>() + .await + } + .await + .context(super::JoinSnafu {})?; + + result.into_iter().collect() } pub(crate) async fn read_csv_schema_single( @@ -300,7 +296,8 @@ mod tests { use crate::CsvParseOptions; #[rstest] - fn test_csv_schema_local( + #[tokio::test] + async fn test_csv_schema_local( #[values( // Uncompressed None, @@ -333,7 +330,8 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (schema, read_stats) = + read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -350,8 +348,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_delimiter() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_delimiter() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_bar_delimiter.csv", env!("CARGO_MANIFEST_DIR"), @@ -367,7 +365,8 @@ mod tests { None, io_client, None, - )?; + ) + .await?; assert_eq!( schema, Schema::new(vec![ @@ -384,23 +383,23 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_read_stats() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_read_stats() -> DaftResult<()> { let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (_, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (_, 
read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!(read_stats.total_bytes_read, 328); assert_eq!(read_stats.total_records_read, 20); Ok(()) } - #[test] - fn test_csv_schema_local_no_headers() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_no_headers() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_no_headers.csv", env!("CARGO_MANIFEST_DIR"), @@ -416,7 +415,8 @@ mod tests { None, io_client, None, - )?; + ) + .await?; assert_eq!( schema, Schema::new(vec![ @@ -433,8 +433,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_empty_lines_skipped() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_empty_lines_skipped() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_empty_lines.csv", env!("CARGO_MANIFEST_DIR"), @@ -444,7 +444,8 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (schema, read_stats) = + read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -461,15 +462,16 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_nulls() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_nulls() -> DaftResult<()> { let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (schema, read_stats) = + read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -486,8 +488,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_conflicting_dtypes.csv", env!("CARGO_MANIFEST_DIR"), @@ -497,7 +499,8 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (schema, read_stats) = + read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -515,8 +518,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_max_bytes() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_max_bytes() -> DaftResult<()> { let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); @@ -524,7 +527,7 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); let (schema, read_stats) = - read_csv_schema(file.as_ref(), None, Some(100), io_client, None)?; + read_csv_schema(file.as_ref(), None, Some(100), io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -550,8 +553,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_invalid_column_header_mismatch() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_invalid_column_header_mismatch() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_invalid_header_cols_mismatch.csv", env!("CARGO_MANIFEST_DIR"), @@ -561,7 +564,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); 
- let err = read_csv_schema(file.as_ref(), None, None, io_client, None); + let err = read_csv_schema(file.as_ref(), None, None, io_client, None).await; assert!(err.is_err()); let err = err.unwrap_err(); assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); @@ -575,8 +578,8 @@ mod tests { Ok(()) } - #[test] - fn test_csv_schema_local_invalid_no_header_variable_num_cols() -> DaftResult<()> { + #[tokio::test] + async fn test_csv_schema_local_invalid_no_header_variable_num_cols() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_invalid_no_header_variable_num_cols.csv", env!("CARGO_MANIFEST_DIR"), @@ -592,7 +595,8 @@ mod tests { None, io_client, None, - ); + ) + .await; assert!(err.is_err()); let err = err.unwrap_err(); assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); @@ -607,7 +611,8 @@ mod tests { } #[rstest] - fn test_csv_schema_s3( + #[tokio::test] + async fn test_csv_schema_s3( #[values( // Uncompressed None, @@ -639,7 +644,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; + let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ diff --git a/src/daft-csv/src/python.rs b/src/daft-csv/src/python.rs index 591ed8cb90..202eaf67d3 100644 --- a/src/daft-csv/src/python.rs +++ b/src/daft-csv/src/python.rs @@ -55,13 +55,20 @@ pub mod pylib { multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - let (schema, _) = crate::metadata::read_csv_schema( - uri, - parse_options, - max_bytes, - io_client, - Some(io_stats), - )?; + + let runtime = common_runtime::get_io_runtime(multithreaded_io.unwrap_or(true)); + + let (schema, _) = runtime.block_on_current_thread(async move { + crate::metadata::read_csv_schema( + uri, + parse_options, + max_bytes, + io_client, + Some(io_stats), + ) + .await + })?; + Ok(Arc::new(schema).into()) }) } diff --git a/src/daft-json/src/python.rs b/src/daft-json/src/python.rs index d09912aca9..b1bff56031 100644 --- a/src/daft-json/src/python.rs +++ b/src/daft-json/src/python.rs @@ -57,13 +57,20 @@ pub mod pylib { multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - let schema = crate::schema::read_json_schema( - uri, - parse_options, - max_bytes, - io_client, - Some(io_stats), - )?; + + let runtime_handle = common_runtime::get_io_runtime(true); + + let schema = runtime_handle.block_on_current_thread(async { + crate::schema::read_json_schema( + uri, + parse_options, + max_bytes, + io_client, + Some(io_stats), + ) + .await + })?; + Ok(Arc::new(schema).into()) }) } diff --git a/src/daft-json/src/schema.rs b/src/daft-json/src/schema.rs index e9af632c4e..320cd6f91e 100644 --- a/src/daft-json/src/schema.rs +++ b/src/daft-json/src/schema.rs @@ -49,25 +49,22 @@ impl Default for JsonReadStats { } } -pub fn read_json_schema( +pub async fn read_json_schema( uri: &str, parse_options: Option, max_bytes: Option, io_client: Arc, io_stats: Option, ) -> DaftResult { - let runtime_handle = get_io_runtime(true); - runtime_handle.block_on_current_thread(async { - read_json_schema_single( - uri, - parse_options.unwrap_or_default(), - // Default to 1 MiB. - max_bytes.or(Some(1024 * 1024)), - io_client, - io_stats, - ) - .await - }) + read_json_schema_single( + uri, + parse_options.unwrap_or_default(), + // Default to 1 MiB. 
+ max_bytes.or(Some(1024 * 1024)), + io_client, + io_stats, + ) + .await } pub async fn read_json_schema_bulk( @@ -205,7 +202,8 @@ mod tests { use super::read_json_schema; #[rstest] - fn test_json_schema_local( + #[tokio::test] + async fn test_json_schema_local( #[values( // Uncompressed None, @@ -239,7 +237,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -255,7 +253,8 @@ mod tests { } #[rstest] - fn test_json_schema_local_dtypes() -> DaftResult<()> { + #[tokio::test] + async fn test_json_schema_local_dtypes() -> DaftResult<()> { let file = format!("{}/test/dtypes.jsonl", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); @@ -263,7 +262,7 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -316,15 +315,15 @@ mod tests { Ok(()) } - #[test] - fn test_json_schema_local_nulls() -> DaftResult<()> { + #[tokio::test] + async fn test_json_schema_local_nulls() -> DaftResult<()> { let file = format!("{}/test/iris_tiny_nulls.jsonl", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -339,8 +338,8 @@ mod tests { Ok(()) } - #[test] - fn test_json_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> { + #[tokio::test] + async fn test_json_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> { let file = format!( "{}/test/iris_tiny_conflicting_dtypes.jsonl", env!("CARGO_MANIFEST_DIR"), @@ -350,7 +349,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -367,15 +366,15 @@ mod tests { Ok(()) } - #[test] - fn test_json_schema_local_max_bytes() -> DaftResult<()> { + #[tokio::test] + async fn test_json_schema_local_max_bytes() -> DaftResult<()> { let file = format!("{}/test/iris_tiny.jsonl", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, Some(100), io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, Some(100), io_client, None).await?; assert_eq!( schema, Schema::new(vec![ @@ -391,7 +390,8 @@ mod tests { } #[rstest] - fn test_json_schema_s3( + #[tokio::test] + async fn test_json_schema_s3( #[values( // Uncompressed None, @@ -424,7 +424,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None).await?; assert_eq!( schema, Schema::new(vec![ diff 
--git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index 23b627612e..036d09df02 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -240,17 +240,23 @@ pub mod pylib { multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - Ok(Arc::new( + + let runtime_handle = common_runtime::get_io_runtime(true); + + let task = async move { crate::read::read_parquet_schema( uri, io_client, Some(io_stats), schema_infer_options, None, // TODO: allow passing in of field_id_mapping through Python API? - )? - .0, - ) - .into()) + ) + .await + }; + + let (schema, _) = runtime_handle.block_on_current_thread(task)?; + + Ok(Arc::new(schema).into()) }) } diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 73141bf7ac..6d3ee8b74b 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -954,17 +954,15 @@ pub fn read_parquet_into_pyarrow_bulk>( Ok(collected.into_iter().map(|(_, v)| v).collect()) } -pub fn read_parquet_schema( +pub async fn read_parquet_schema( uri: &str, io_client: Arc, io_stats: Option, schema_inference_options: ParquetSchemaInferenceOptions, field_id_mapping: Option>>, ) -> DaftResult<(Schema, FileMetaData)> { - let runtime_handle = get_io_runtime(true); - let builder = runtime_handle.block_on_current_thread(async { - ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats, field_id_mapping).await - })?; + let builder = + ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats, field_id_mapping).await?; let builder = builder.set_infer_schema_options(schema_inference_options); let metadata = builder.metadata; diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml index 961d59fbff..12eeb71f5b 100644 --- a/src/daft-scan/Cargo.toml +++ b/src/daft-scan/Cargo.toml @@ -28,6 +28,9 @@ snafu = {workspace = true} typetag = {workspace = true} urlencoding = "2.1.3" +[dev-dependencies] +tokio = {workspace = true, features = ["full"]} + [features] python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-logical-plan/python", "daft-table/python", "daft-stats/python", "common-file-formats/python", "common-io-config/python", "common-daft-config/python", "common-scan-info/python", "daft-schema/python"] diff --git a/src/daft-scan/src/builder.rs b/src/daft-scan/src/builder.rs index fb5b6acd91..158c272495 100644 --- a/src/daft-scan/src/builder.rs +++ b/src/daft-scan/src/builder.rs @@ -97,7 +97,7 @@ impl ParquetScanBuilder { self } - pub fn finish(self) -> DaftResult { + pub async fn finish(self) -> DaftResult { let cfg = ParquetSourceConfig { coerce_int96_timestamp_unit: self.coerce_int96_timestamp_unit, field_id_mapping: self.field_id_mapping, @@ -105,17 +105,20 @@ impl ParquetScanBuilder { chunk_size: self.chunk_size, }; - let operator = Arc::new(GlobScanOperator::try_new( - self.glob_paths, - Arc::new(FileFormatConfig::Parquet(cfg)), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(self.multithreaded, self.io_config), - ))), - self.infer_schema, - self.schema, - self.file_path_column, - self.hive_partitioning, - )?); + let operator = Arc::new( + GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Parquet(cfg)), + Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(self.multithreaded, self.io_config), + ))), + self.infer_schema, + self.schema, + self.file_path_column, + self.hive_partitioning, + ) + .await?, + ); LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), 
None) } @@ -238,7 +241,7 @@ impl CsvScanBuilder { self } - pub fn finish(self) -> DaftResult { + pub async fn finish(self) -> DaftResult { let cfg = CsvSourceConfig { delimiter: self.delimiter, has_headers: self.has_headers, @@ -251,17 +254,20 @@ impl CsvScanBuilder { chunk_size: self.chunk_size, }; - let operator = Arc::new(GlobScanOperator::try_new( - self.glob_paths, - Arc::new(FileFormatConfig::Csv(cfg)), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(false, self.io_config), - ))), - self.infer_schema, - self.schema, - self.file_path_column, - self.hive_partitioning, - )?); + let operator = Arc::new( + GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Csv(cfg)), + Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(false, self.io_config), + ))), + self.infer_schema, + self.schema, + self.file_path_column, + self.hive_partitioning, + ) + .await?, + ); LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index a1331457fe..899e0ebc89 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -16,7 +16,7 @@ use daft_schema::{ }; use daft_stats::PartitionSpec; use daft_table::Table; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; use snafu::Snafu; use crate::{ @@ -73,32 +73,23 @@ impl From for DaftError { } } -type FileInfoIterator = Box>>; - -fn run_glob( +async fn run_glob( glob_path: &str, limit: Option, io_client: Arc, - runtime: RuntimeRef, io_stats: Option, file_format: FileFormat, -) -> DaftResult { +) -> DaftResult> + Send> { let (_, parsed_glob_path) = parse_url(glob_path)?; // Construct a static-lifetime BoxStream returning the FileMetadata let glob_input = parsed_glob_path.as_ref().to_string(); - let boxstream = runtime.block_on_current_thread(async move { - io_client - .glob(glob_input, None, None, limit, io_stats, Some(file_format)) - .await - })?; + let stream = io_client + .glob(glob_input, None, None, limit, io_stats, Some(file_format)) + .await?; - // Construct a static-lifetime BoxStreamIterator - let iterator = BoxStreamIterator { - boxstream, - runtime_handle: runtime.clone(), - }; - let iterator = iterator.map(|fm| Ok(fm?)); - Ok(Box::new(iterator)) + let stream = stream.map_err(|e| e.into()); + + Ok(stream) } fn run_glob_parallel( @@ -139,7 +130,7 @@ fn run_glob_parallel( } impl GlobScanOperator { - pub fn try_new( + pub async fn try_new( glob_paths: Vec, file_format_config: Arc, storage_config: Arc, @@ -157,7 +148,7 @@ impl GlobScanOperator { let file_format = file_format_config.file_format(); - let (io_runtime, io_client) = storage_config.get_io_client_and_runtime()?; + let (_, io_client) = storage_config.get_io_client_and_runtime()?; let io_stats = IOStatsContext::new(format!( "GlobScanOperator::try_new schema inference for {first_glob_path}" )); @@ -165,14 +156,15 @@ impl GlobScanOperator { first_glob_path, Some(1), io_client.clone(), - io_runtime, Some(io_stats.clone()), file_format, - )?; + ) + .await?; + let FileMetadata { filepath: first_filepath, .. 
- } = match paths.next() { + } = match paths.next().await { Some(file_metadata) => file_metadata, None => Err(Error::GlobNoMatch { glob_path: first_glob_path.to_string(), @@ -228,7 +220,8 @@ impl GlobScanOperator { ..Default::default() }, field_id_mapping.clone(), - )?; + ) + .await?; schema } @@ -256,16 +249,20 @@ impl GlobScanOperator { None, io_client, Some(io_stats), - )?; + ) + .await?; schema } - FileFormatConfig::Json(_) => daft_json::schema::read_json_schema( - first_filepath.as_str(), - None, - None, - io_client, - Some(io_stats), - )?, + FileFormatConfig::Json(_) => { + daft_json::schema::read_json_schema( + first_filepath.as_str(), + None, + None, + io_client, + Some(io_stats), + ) + .await? + } #[cfg(feature = "python")] FileFormatConfig::Database(_) => { return Err(DaftError::ValueError( diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index cce0a22230..788c0f3b60 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -813,7 +813,7 @@ mod test { ) } - fn make_glob_scan_operator(num_sources: usize) -> GlobScanOperator { + async fn make_glob_scan_operator(num_sources: usize) -> GlobScanOperator { let file_format_config: FileFormatConfig = FileFormatConfig::Parquet(ParquetSourceConfig { coerce_int96_timestamp_unit: TimeUnit::Seconds, field_id_mapping: None, @@ -838,14 +838,15 @@ mod test { None, false, ) + .await .unwrap(); glob_scan_operator } - #[test] - fn test_glob_display_condenses() -> DaftResult<()> { - let glob_scan_operator: GlobScanOperator = make_glob_scan_operator(8); + #[tokio::test] + async fn test_glob_display_condenses() -> DaftResult<()> { + let glob_scan_operator: GlobScanOperator = make_glob_scan_operator(8).await; let condensed_glob_paths: Vec = glob_scan_operator.multiline_display(); assert_eq!(condensed_glob_paths[1], "Glob paths = [../../tests/assets/parquet-data/mvp.parquet, ../../tests/assets/parquet-data/mvp.parquet, ../../tests/assets/parquet-data/mvp.parquet, ..., ../../tests/assets/parquet-data/mvp.parquet, ../../tests/assets/parquet-data/mvp.parquet, ../../tests/assets/parquet-data/mvp.parquet]"); Ok(()) diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 001476c18e..dd30b4541c 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -117,7 +117,9 @@ pub mod pylib { file_path_column: Option, ) -> PyResult { py.allow_threads(|| { - let operator = Arc::new(GlobScanOperator::try_new( + let executor = common_runtime::get_io_runtime(true); + + let task = GlobScanOperator::try_new( glob_path, file_format_config.into(), storage_config.into(), @@ -125,7 +127,11 @@ pub mod pylib { schema.map(|s| s.schema), file_path_column, hive_partitioning, - )?); + ); + + let operator = executor.block_on(task)??; + let operator = Arc::new(operator); + Ok(Self { scan_op: ScanOperatorRef(operator), }) diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index bb17a42854..6e45c23741 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -2,6 +2,7 @@ common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} common-io-config = {path = "../common/io-config", default-features = false} +common-runtime = {workspace = true} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/table_provider/read_csv.rs b/src/daft-sql/src/table_provider/read_csv.rs index 543bd09948..241fad455f 100644 --- a/src/daft-sql/src/table_provider/read_csv.rs +++ 
b/src/daft-sql/src/table_provider/read_csv.rs @@ -103,6 +103,8 @@ impl SQLTableFunction for ReadCsvFunction { 1, // 1 positional argument (path) )?; - builder.finish().map_err(From::from) + let runtime = common_runtime::get_io_runtime(true); + let result = runtime.block_on(builder.finish())??; + Ok(result) } } diff --git a/src/daft-sql/src/table_provider/read_parquet.rs b/src/daft-sql/src/table_provider/read_parquet.rs index 643f771dcc..aeae0545ee 100644 --- a/src/daft-sql/src/table_provider/read_parquet.rs +++ b/src/daft-sql/src/table_provider/read_parquet.rs @@ -77,6 +77,9 @@ impl SQLTableFunction for ReadParquetFunction { 1, // 1 positional argument (path) )?; - builder.finish().map_err(From::from) + let runtime = common_runtime::get_io_runtime(true); + + let result = runtime.block_on(builder.finish())??; + Ok(result) } }
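
Note (not part of the patch): a minimal sketch of the calling convention this change establishes, shown for a synchronous caller of the CSV reader. Schema inference itself stays async and runtime-agnostic; the single blocking call sits with the outermost caller, mirroring the new daft-csv python.rs code above. Import paths (daft_io, daft_csv, daft_core, common_error, common_runtime) are assumed from this workspace, and the wrapper function name is made up for illustration.

use std::sync::Arc;

use common_error::DaftResult;
use daft_core::prelude::Schema;
use daft_io::{IOClient, IOConfig};

// Hypothetical synchronous entry point; in the patch this role is played by
// the Python API layer (e.g. daft-csv/src/python.rs).
fn infer_csv_schema_at_boundary(uri: &str) -> DaftResult<Schema> {
    // Build the IO client the same way the touched tests do.
    let mut io_config = IOConfig::default();
    io_config.s3.anonymous = true;
    let io_client = Arc::new(IOClient::new(io_config.into())?);

    // Schema inference is now an async fn; composing it does not block.
    let task = async move {
        daft_csv::metadata::read_csv_schema(uri, None, None, io_client, None).await
    };

    // Blocking happens exactly once, at the boundary, on the shared IO runtime.
    let runtime = common_runtime::get_io_runtime(true);
    let (schema, _read_stats) = runtime.block_on_current_thread(task)?;

    Ok(schema)
}

The JSON and Parquet schema readers follow the same shape, as do the scan builders: finish() is now async, and synchronous callers such as the SQL table providers drive it with runtime.block_on(builder.finish())?? as in the last two hunks.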