From 2a0c3455f3ea15ac511e82e656d84aef1922594b Mon Sep 17 00:00:00 2001
From: Simon Lin
Date: Mon, 30 Sep 2024 16:19:14 +1000
Subject: [PATCH] feat: Support `schema` arg in `read/scan_parquet()`

---
 crates/polars-io/src/parquet/read/mod.rs      |  1 +
 crates/polars-io/src/parquet/read/options.rs  |  4 +-
 crates/polars-io/src/parquet/read/reader.rs   |  6 +-
 crates/polars-io/src/parquet/read/utils.rs    | 34 ++++++++-
 crates/polars-lazy/src/scan/parquet.rs        |  3 +
 .../src/executors/scan/parquet.rs             | 12 +++-
 .../src/executors/sources/parquet.rs          |  7 +-
 crates/polars-plan/src/plans/builder_dsl.rs   |  2 +
 crates/polars-python/src/lazyframe/general.rs |  4 +-
 .../src/nodes/parquet_source/init.rs          | 18 +----
 .../nodes/parquet_source/metadata_fetch.rs    | 17 ++---
 .../nodes/parquet_source/metadata_utils.rs    |  4 +-
 .../src/nodes/parquet_source/mod.rs           | 15 ++++
 py-polars/polars/io/parquet/functions.py      | 17 +++++
 py-polars/tests/unit/io/test_lazy_parquet.py  | 72 +++++++++++++++++++
 15 files changed, 179 insertions(+), 37 deletions(-)

diff --git a/crates/polars-io/src/parquet/read/mod.rs b/crates/polars-io/src/parquet/read/mod.rs
index 5c722c5b77cd..1fec749af5ce 100644
--- a/crates/polars-io/src/parquet/read/mod.rs
+++ b/crates/polars-io/src/parquet/read/mod.rs
@@ -42,4 +42,5 @@ pub mod _internal {
     pub use super::mmap::to_deserializer;
     pub use super::predicates::read_this_row_group;
     pub use super::read_impl::{calc_prefilter_cost, PrefilterMaskSetting};
+    pub use super::utils::ensure_matching_dtypes_if_found;
 }
diff --git a/crates/polars-io/src/parquet/read/options.rs b/crates/polars-io/src/parquet/read/options.rs
index 37357af89d67..6465620f6860 100644
--- a/crates/polars-io/src/parquet/read/options.rs
+++ b/crates/polars-io/src/parquet/read/options.rs
@@ -1,9 +1,11 @@
+use arrow::datatypes::ArrowSchemaRef;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
 
-#[derive(Clone, Debug, PartialEq, Eq, Copy, Hash)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct ParquetOptions {
+    pub schema: Option<ArrowSchemaRef>,
     pub parallel: ParallelStrategy,
     pub low_memory: bool,
     pub use_statistics: bool,
diff --git a/crates/polars-io/src/parquet/read/reader.rs b/crates/polars-io/src/parquet/read/reader.rs
index eb3609c127ac..2a70ef2c5046 100644
--- a/crates/polars-io/src/parquet/read/reader.rs
+++ b/crates/polars-io/src/parquet/read/reader.rs
@@ -15,7 +15,7 @@ pub use super::read_impl::BatchedParquetReader;
 use super::read_impl::{compute_row_group_range, read_parquet, FetchRowGroupsFromMmapReader};
 #[cfg(feature = "cloud")]
 use super::utils::materialize_empty_df;
-use super::utils::projected_arrow_schema_to_projection_indices;
+use super::utils::{ensure_matching_dtypes_if_found, projected_arrow_schema_to_projection_indices};
 #[cfg(feature = "cloud")]
 use crate::cloud::CloudOptions;
 use crate::mmap::MmapBytesReader;
@@ -90,6 +90,8 @@ impl ParquetReader {
         allow_missing_columns: bool,
     ) -> PolarsResult {
         if allow_missing_columns {
+            // Must check the dtypes
+            ensure_matching_dtypes_if_found(first_schema, self.schema()?.as_ref())?;
             self.schema.replace(first_schema.clone());
         }
 
@@ -327,6 +329,8 @@ impl ParquetAsyncReader {
         allow_missing_columns: bool,
     ) -> PolarsResult {
         if allow_missing_columns {
+            // Must check the dtypes
+            ensure_matching_dtypes_if_found(first_schema, self.schema().await?.as_ref())?;
             self.schema.replace(first_schema.clone());
         }
 
diff --git a/crates/polars-io/src/parquet/read/utils.rs b/crates/polars-io/src/parquet/read/utils.rs
index 7ce183088aee..87ce2600cbdc 100644
--- a/crates/polars-io/src/parquet/read/utils.rs
+++ b/crates/polars-io/src/parquet/read/utils.rs
@@ -1,6 +1,7 @@
 use std::borrow::Cow;
 
 use polars_core::prelude::{ArrowSchema, DataFrame, DataType, Series, IDX_DTYPE};
+use polars_core::schema::SchemaNamesAndDtypes;
 use polars_error::{polars_bail, PolarsResult};
 
 use crate::hive::materialize_hive_partitions;
@@ -51,11 +52,40 @@ pub(super) fn projected_arrow_schema_to_projection_indices(
         let expected_dtype = DataType::from_arrow(&field.dtype, true);
 
         if dtype.clone() != expected_dtype {
-            polars_bail!(SchemaMismatch: "data type mismatch for column {}: found: {}, expected: {}",
-                &field.name, dtype, expected_dtype
+            polars_bail!(SchemaMismatch: "data type mismatch for column {}: expected: {}, found: {}",
+                &field.name, expected_dtype, dtype
             )
         }
     }
 
     Ok((!is_full_ordered_projection).then_some(projection_indices))
 }
+
+/// Utility to ensure the dtype of the column in `current_schema` matches the dtype in `schema` if
+/// that column exists in `schema`.
+pub fn ensure_matching_dtypes_if_found(
+    schema: &ArrowSchema,
+    current_schema: &ArrowSchema,
+) -> PolarsResult<()> {
+    current_schema
+        .iter_names_and_dtypes()
+        .try_for_each(|(name, dtype)| {
+            if let Some(field) = schema.get(name) {
+                if dtype != &field.dtype {
+                    // Check again with timezone normalization
+                    // TODO: Add an ArrowDtype eq wrapper?
+                    let lhs = DataType::from_arrow(dtype, true);
+                    let rhs = DataType::from_arrow(&field.dtype, true);
+
+                    if lhs != rhs {
+                        polars_bail!(
+                            SchemaMismatch:
+                            "dtypes differ for column {}: {:?} != {:?}"
+                            , name, dtype, &field.dtype
+                        );
+                    }
+                }
+            }
+            Ok(())
+        })
+}
diff --git a/crates/polars-lazy/src/scan/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs
index 382addea7ed1..179f56212bec 100644
--- a/crates/polars-lazy/src/scan/parquet.rs
+++ b/crates/polars-lazy/src/scan/parquet.rs
@@ -15,6 +15,7 @@ pub struct ScanArgsParquet {
     pub cloud_options: Option,
     pub hive_options: HiveOptions,
     pub use_statistics: bool,
+    pub schema: Option<SchemaRef>,
     pub low_memory: bool,
     pub rechunk: bool,
     pub cache: bool,
@@ -33,6 +34,7 @@ impl Default for ScanArgsParquet {
             cloud_options: None,
             hive_options: Default::default(),
             use_statistics: true,
+            schema: None,
             rechunk: false,
             low_memory: false,
             cache: true,
@@ -73,6 +75,7 @@ impl LazyFileListReader for LazyParquetReader {
             self.args.low_memory,
             self.args.cloud_options,
             self.args.use_statistics,
+            self.args.schema.as_deref(),
             self.args.hive_options,
             self.args.glob,
             self.args.include_file_paths,
diff --git a/crates/polars-mem-engine/src/executors/scan/parquet.rs b/crates/polars-mem-engine/src/executors/scan/parquet.rs
index 49b01c471610..6661e4c5a3cc 100644
--- a/crates/polars-mem-engine/src/executors/scan/parquet.rs
+++ b/crates/polars-mem-engine/src/executors/scan/parquet.rs
@@ -62,7 +62,11 @@ impl ParquetExec {
         // Modified if we have a negative slice
         let mut first_source = 0;
 
-        let first_schema = self.file_info.reader_schema.clone().unwrap().unwrap_left();
+        let first_schema = self
+            .options
+            .schema
+            .clone()
+            .unwrap_or_else(|| self.file_info.reader_schema.clone().unwrap().unwrap_left());
 
         let projected_arrow_schema = {
             if let Some(with_columns) = self.file_options.with_columns.as_deref() {
@@ -258,7 +262,11 @@ impl ParquetExec {
             eprintln!("POLARS PREFETCH_SIZE: {}", batch_size)
         }
 
-        let first_schema = self.file_info.reader_schema.clone().unwrap().unwrap_left();
+        let first_schema = self
+            .options
+            .schema
+            .clone()
+            .unwrap_or_else(|| self.file_info.reader_schema.clone().unwrap().unwrap_left());
 
         let projected_arrow_schema = {
             if let Some(with_columns) = self.file_options.with_columns.as_deref() {
diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs
index efe3edac1b87..1ab80372b9a4 100644
--- a/crates/polars-pipe/src/executors/sources/parquet.rs
+++ b/crates/polars-pipe/src/executors/sources/parquet.rs
@@ -81,7 +81,7 @@ impl ParquetSource {
             .as_paths()
             .ok_or_else(|| polars_err!(nyi = "Streaming scanning of in-memory buffers"))?;
         let path = &paths[index];
-        let options = self.options;
+        let options = self.options.clone();
         let file_options = self.file_options.clone();
 
         let hive_partitions = self
@@ -261,7 +261,10 @@ impl ParquetSource {
         }
 
         let run_async = paths.first().map(is_cloud_url).unwrap_or(false) || config::force_async();
-        let first_schema = file_info.reader_schema.clone().unwrap().unwrap_left();
+        let first_schema = options
+            .schema
+            .clone()
+            .unwrap_or_else(|| file_info.reader_schema.clone().unwrap().unwrap_left());
 
         let projected_arrow_schema = {
             if let Some(with_columns) = file_options.with_columns.as_deref() {
diff --git a/crates/polars-plan/src/plans/builder_dsl.rs b/crates/polars-plan/src/plans/builder_dsl.rs
index bc695b47a035..fb82585ae476 100644
--- a/crates/polars-plan/src/plans/builder_dsl.rs
+++ b/crates/polars-plan/src/plans/builder_dsl.rs
@@ -85,6 +85,7 @@ impl DslBuilder {
         low_memory: bool,
         cloud_options: Option,
         use_statistics: bool,
+        schema: Option<&Schema>,
         hive_options: HiveOptions,
         glob: bool,
         include_file_paths: Option,
@@ -108,6 +109,7 @@ impl DslBuilder {
             file_options: options,
             scan_type: FileScan::Parquet {
                 options: ParquetOptions {
+                    schema: schema.map(|x| Arc::new(x.to_arrow(CompatLevel::newest()))),
                     parallel,
                     low_memory,
                     use_statistics,
diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs
index 18b9323388e7..b5e9bfc7bd50 100644
--- a/crates/polars-python/src/lazyframe/general.rs
+++ b/crates/polars-python/src/lazyframe/general.rs
@@ -240,7 +240,7 @@ impl PyLazyFrame {
     #[cfg(feature = "parquet")]
     #[staticmethod]
     #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index,
-        low_memory, cloud_options, use_statistics, hive_partitioning, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns)
+        low_memory, cloud_options, use_statistics, hive_partitioning, schema, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns)
     )]
     fn new_from_parquet(
         source: Option,
@@ -254,6 +254,7 @@ impl PyLazyFrame {
         cloud_options: Option>,
         use_statistics: bool,
         hive_partitioning: Option,
+        schema: Option<Wrap<Schema>>,
         hive_schema: Option>,
         try_parse_hive_dates: bool,
         retries: usize,
@@ -285,6 +286,7 @@ impl PyLazyFrame {
             low_memory,
             cloud_options: None,
             use_statistics,
+            schema: schema.map(|x| Arc::new(x.0)),
             hive_options,
             glob,
             include_file_paths: include_file_paths.map(|x| x.into()),
diff --git a/crates/polars-stream/src/nodes/parquet_source/init.rs b/crates/polars-stream/src/nodes/parquet_source/init.rs
index 07f1a55cdee2..3187bbe797e4 100644
--- a/crates/polars-stream/src/nodes/parquet_source/init.rs
+++ b/crates/polars-stream/src/nodes/parquet_source/init.rs
@@ -85,14 +85,7 @@ impl ParquetSourceNode {
             );
         }
 
-        let reader_schema = self
-            .file_info
-            .reader_schema
-            .as_ref()
-            .unwrap()
-            .as_ref()
-            .unwrap_left()
-            .clone();
+        let reader_schema = self.schema.clone().unwrap();
 
         let (normalized_slice_oneshot_rx, metadata_rx, metadata_task_handle) =
             self.init_metadata_fetcher();
@@ -361,14 +354,7 @@ impl ParquetSourceNode {
     }
 
     pub(super) fn init_projected_arrow_schema(&mut self) {
-        let reader_schema = self
-            .file_info
-            .reader_schema
-            .as_ref()
-            .unwrap()
-            .as_ref()
-            .unwrap_left()
-            .clone();
+        let reader_schema = self.schema.clone().unwrap();
 
         self.projected_arrow_schema = Some(
             if let Some(columns) = self.file_options.with_columns.as_deref() {
diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs
index 0bee88861769..746c517ce744 100644
--- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs
+++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 use futures::StreamExt;
 use polars_error::{polars_bail, PolarsResult};
 use polars_io::prelude::FileMetadata;
+use polars_io::prelude::_internal::ensure_matching_dtypes_if_found;
 use polars_io::utils::byte_source::{DynByteSource, MemSliceByteSource};
 use polars_io::utils::slice::SplitSlicePosition;
 use polars_utils::mmap::MemSlice;
@@ -106,14 +107,7 @@ impl ParquetSourceNode {
         };
 
         let first_metadata = self.first_metadata.clone();
-        let reader_schema_len = self
-            .file_info
-            .reader_schema
-            .as_ref()
-            .unwrap()
-            .as_ref()
-            .unwrap_left()
-            .len();
+        let first_schema = self.schema.clone().unwrap();
 
         let has_projection = self.file_options.with_columns.is_some();
         let allow_missing_columns = self.file_options.allow_missing_columns;
@@ -121,6 +115,7 @@ impl ParquetSourceNode {
             move |handle: task_handles_ext::AbortOnDropHandle<
                 PolarsResult<(usize, Arc, MemSlice)>,
             >| {
+                let first_schema = first_schema.clone();
                 let projected_arrow_schema = projected_arrow_schema.clone();
                 let first_metadata = first_metadata.clone();
                 // Run on CPU runtime - metadata deserialization is expensive, especially
@@ -138,14 +133,16 @@ impl ParquetSourceNode {
 
                     let schema = polars_parquet::arrow::read::infer_schema(&metadata)?;
 
-                    if !has_projection && schema.len() > reader_schema_len {
+                    if !has_projection && schema.len() > first_schema.len() {
                         polars_bail!(
                             SchemaMismatch:
                             "parquet file contained extra columns and no selection was given"
                         )
                     }
 
-                    if !allow_missing_columns {
+                    if allow_missing_columns {
+                        ensure_matching_dtypes_if_found(&first_schema, &schema)?;
+                    } else {
                         ensure_schema_has_projected_fields(
                             &schema,
                             projected_arrow_schema.as_ref(),
diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs
index 61db45d54a0a..d4646bf2e1a7 100644
--- a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs
+++ b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs
@@ -136,8 +136,8 @@ pub(super) fn ensure_schema_has_projected_fields(
         };
 
         if dtype != expected_dtype {
-            polars_bail!(SchemaMismatch: "data type mismatch for column {}: found: {}, expected: {}",
-                &field.name, dtype, expected_dtype
+            polars_bail!(SchemaMismatch: "data type mismatch for column {}: expected: {}, found: {}",
+                &field.name, expected_dtype, dtype
             )
         }
     }
diff --git a/crates/polars-stream/src/nodes/parquet_source/mod.rs b/crates/polars-stream/src/nodes/parquet_source/mod.rs
index c9d2538db714..4a09e6fca1e0 100644
--- a/crates/polars-stream/src/nodes/parquet_source/mod.rs
+++ b/crates/polars-stream/src/nodes/parquet_source/mod.rs
@@ -47,6 +47,7 @@ pub struct ParquetSourceNode {
     config: Config,
     verbose: bool,
     physical_predicate: Option>,
+    schema: Option<Arc<ArrowSchema>>,
     projected_arrow_schema: Option>,
     byte_source_builder: DynByteSourceBuilder,
     memory_prefetch_func: fn(&[u8]) -> (),
@@ -112,6 +113,7 @@ impl ParquetSourceNode {
             },
             verbose,
             physical_predicate: None,
+            schema: None,
             projected_arrow_schema: None,
             byte_source_builder,
             memory_prefetch_func,
@@ -154,6 +156,19 @@ impl ComputeNode for ParquetSourceNode {
             eprintln!("[ParquetSource]: {:?}", &self.config);
         }
 
+        self.schema = Some(
+            self.options
+                .schema
+                .take()
+                .unwrap_or_else(|| self.file_info.reader_schema.take().unwrap().unwrap_left()),
+        );
+
+        {
+            // Ensure these are not used anymore
+            self.options.schema.take();
+            self.file_info.reader_schema.take();
+        }
+
         self.init_projected_arrow_schema();
 
         self.physical_predicate = self.predicate.clone().map(phys_expr_to_io_expr);
diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py
index cfb765670054..c634bc86f9a3 100644
--- a/py-polars/polars/io/parquet/functions.py
+++ b/py-polars/polars/io/parquet/functions.py
@@ -51,6 +51,7 @@ def read_parquet(
     use_statistics: bool = True,
     hive_partitioning: bool | None = None,
     glob: bool = True,
+    schema: SchemaDict | None = None,
     hive_schema: SchemaDict | None = None,
     try_parse_hive_dates: bool = True,
     rechunk: bool = False,
@@ -102,6 +103,10 @@ def read_parquet(
         disabled.
     glob
         Expand path given via globbing rules.
+    schema
+        Specify the datatypes of the columns. The datatypes must match the
+        datatypes in the file(s). If there are extra columns that are not in the
+        file(s), consider also enabling `allow_missing_columns`.
     hive_schema
         The column names and data types of the columns by which the data is
         partitioned. If set to `None` (default), the schema of the Hive partitions is inferred.
@@ -172,6 +177,9 @@ def read_parquet(
         if include_file_paths is not None:
             msg = "`include_file_paths` cannot be used with `use_pyarrow=True`"
             raise ValueError(msg)
+        if schema is not None:
+            msg = "`schema` cannot be used with `use_pyarrow=True`"
+            raise ValueError(msg)
         if hive_schema is not None:
             msg = (
                 "cannot use `hive_partitions` with `use_pyarrow=True`"
@@ -203,6 +211,7 @@ def read_parquet(
         parallel=parallel,
         use_statistics=use_statistics,
         hive_partitioning=hive_partitioning,
+        schema=schema,
        hive_schema=hive_schema,
         try_parse_hive_dates=try_parse_hive_dates,
         rechunk=rechunk,
@@ -314,6 +323,7 @@ def scan_parquet(
     use_statistics: bool = True,
     hive_partitioning: bool | None = None,
    glob: bool = True,
+    schema: SchemaDict | None = None,
     hive_schema: SchemaDict | None = None,
     try_parse_hive_dates: bool = True,
     rechunk: bool = False,
@@ -370,6 +380,10 @@ def scan_parquet(
         to prune reads.
     glob
         Expand path given via globbing rules.
+    schema
+        Specify the datatypes of the columns. The datatypes must match the
+        datatypes in the file(s). If there are extra columns that are not in the
+        file(s), consider also enabling `allow_missing_columns`.
     hive_schema
         The column names and data types of the columns by which the data is
         partitioned. If set to `None` (default), the schema of the Hive partitions is inferred.
@@ -456,6 +470,7 @@ def scan_parquet(
         low_memory=low_memory,
         use_statistics=use_statistics,
         hive_partitioning=hive_partitioning,
+        schema=schema,
         hive_schema=hive_schema,
         try_parse_hive_dates=try_parse_hive_dates,
         retries=retries,
@@ -479,6 +494,7 @@ def _scan_parquet_impl(
     use_statistics: bool = True,
     hive_partitioning: bool | None = None,
     glob: bool = True,
+    schema: SchemaDict | None = None,
    hive_schema: SchemaDict | None = None,
     try_parse_hive_dates: bool = True,
     retries: int = 2,
@@ -509,6 +525,7 @@ def _scan_parquet_impl(
         cloud_options=storage_options,
         use_statistics=use_statistics,
         hive_partitioning=hive_partitioning,
+        schema=schema,
         hive_schema=hive_schema,
         try_parse_hive_dates=try_parse_hive_dates,
         retries=retries,
diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py
index 9604793c7baf..45179a1520d7 100644
--- a/py-polars/tests/unit/io/test_lazy_parquet.py
+++ b/py-polars/tests/unit/io/test_lazy_parquet.py
@@ -647,3 +647,75 @@ def test_parquet_unaligned_schema_read_missing_cols_from_first(
         match="did not find column in file: a",
     ):
         lf.collect(streaming=streaming)
+
+
+@pytest.mark.parametrize("parallel", ["columns", "row_groups", "prefiltered", "none"])
+@pytest.mark.parametrize("streaming", [True, False])
+@pytest.mark.write_disk
+def test_parquet_schema_arg(
+    tmp_path: Path,
+    parallel: ParallelStrategy,
+    streaming: bool,
+) -> None:
+    tmp_path.mkdir(exist_ok=True)
+    dfs = [pl.DataFrame({"a": 1, "b": 1}), pl.DataFrame({"a": 2, "b": 2})]
+    paths = [tmp_path / "1", tmp_path / "2"]
+
+    for df, path in zip(dfs, paths):
+        df.write_parquet(path)
+
+    schema: dict[str, pl.DataType] = {
+        "1": pl.Datetime(time_unit="ms", time_zone="CET"),
+        "a": pl.Int64(),
+        "b": pl.Int64(),
+    }
+
+    # Test `schema` containing an extra column.
+
+    lf = pl.scan_parquet(paths, parallel=parallel, schema=schema)
+
+    with pytest.raises(pl.exceptions.ColumnNotFoundError):
+        lf.collect(streaming=streaming)
+
+    lf = pl.scan_parquet(
+        paths, parallel=parallel, schema=schema, allow_missing_columns=True
+    )
+
+    assert_frame_equal(
+        lf.collect(streaming=streaming),
+        pl.DataFrame({"1": None, "a": [1, 2], "b": [1, 2]}, schema=schema),
+    )
+
+    # Just one test that `read_parquet` is propagating this argument.
+    assert_frame_equal(
+        pl.read_parquet(
+            paths, parallel=parallel, schema=schema, allow_missing_columns=True
+        ),
+        pl.DataFrame({"1": None, "a": [1, 2], "b": [1, 2]}, schema=schema),
+    )
+
+    # Test files containing extra columns not in `schema`
+
+    schema: dict[str, type[pl.DataType]] = {"a": pl.Int64}  # type: ignore[no-redef]
+
+    lf = pl.scan_parquet(paths, parallel=parallel, schema=schema)
+
+    with pytest.raises(pl.exceptions.SchemaError, match="file contained extra columns"):
+        lf.collect(streaming=streaming)
+
+    lf = pl.scan_parquet(paths, parallel=parallel, schema=schema).select("a")
+
+    assert_frame_equal(
+        lf.collect(streaming=streaming),
+        pl.DataFrame({"a": [1, 2]}, schema=schema),
+    )
+
+    schema: dict[str, type[pl.DataType]] = {"a": pl.Int64, "b": pl.Int8}  # type: ignore[no-redef]
+
+    lf = pl.scan_parquet(paths, parallel=parallel, schema=schema)
+
+    with pytest.raises(
+        pl.exceptions.SchemaError,
+        match="data type mismatch for column b: expected: i8, found: i64",
+    ):
+        lf.collect(streaming=streaming)
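Usage note (not part of the patch): a minimal sketch of how the new `schema` argument is expected to be used, mirroring the test above. The file paths and the extra `c` column below are illustrative only.

import polars as pl

# Declare the schema up front; dtypes must match the dtypes stored in the files.
# Columns listed here but absent from the files require allow_missing_columns=True
# and come back as all-null columns.
schema = {"a": pl.Int64, "b": pl.Int64, "c": pl.String}

lf = pl.scan_parquet(
    ["data/1.parquet", "data/2.parquet"],  # illustrative paths
    schema=schema,
    allow_missing_columns=True,
)
df = lf.collect()

# A mismatched dtype (e.g. declaring "b" as pl.Int8 when the files store Int64)
# raises a SchemaError at collection time, as exercised by test_parquet_schema_arg.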