diff --git a/Cargo.toml b/Cargo.toml index a48f8c7894..9fe8c44bb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ debug = true debug = "line-tables-only" [workspace.dependencies] -delta_kernel = { version = "0.4.1", features = ["sync-engine"] } +delta_kernel = { version = "0.4.1", features = ["default-engine"] } #delta_kernel = { path = "../delta-kernel-rs/kernel", features = ["sync-engine"] } # arrow diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 948139dcc1..57a9496070 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-core" -version = "0.22.1" +version = "0.22.3" authors.workspace = true keywords.workspace = true readme.workspace = true diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 4425b0ff6f..034781b85c 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -826,9 +826,12 @@ impl TableProvider for DeltaTableProvider { fn supports_filters_pushdown( &self, - _filter: &[&Expr], + filter: &[&Expr], ) -> DataFusionResult> { - Ok(vec![TableProviderFilterPushDown::Inexact]) + Ok(filter + .iter() + .map(|_| TableProviderFilterPushDown::Inexact) + .collect()) } fn statistics(&self) -> Option { @@ -1150,11 +1153,12 @@ pub(crate) async fn execute_plan_to_batch( Ok(concat_batches(&plan.schema(), data.iter())?) } -/// Responsible for checking batches of data conform to table's invariants. -#[derive(Clone)] +/// Responsible for checking batches of data conform to table's invariants, constraints and nullability. +#[derive(Clone, Default)] pub struct DeltaDataChecker { constraints: Vec, invariants: Vec, + non_nullable_columns: Vec, ctx: SessionContext, } @@ -1164,6 +1168,7 @@ impl DeltaDataChecker { Self { invariants: vec![], constraints: vec![], + non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } } @@ -1173,6 +1178,7 @@ impl DeltaDataChecker { Self { invariants, constraints: vec![], + non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } } @@ -1182,6 +1188,7 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], + non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } } @@ -1202,9 +1209,21 @@ impl DeltaDataChecker { pub fn new(snapshot: &DeltaTableState) -> Self { let invariants = snapshot.schema().get_invariants().unwrap_or_default(); let constraints = snapshot.table_config().get_constraints(); + let non_nullable_columns = snapshot + .schema() + .fields() + .filter_map(|f| { + if !f.is_nullable() { + Some(f.name().clone()) + } else { + None + } + }) + .collect_vec(); Self { invariants, constraints, + non_nullable_columns, ctx: DeltaSessionContext::default().into(), } } @@ -1214,10 +1233,35 @@ impl DeltaDataChecker { /// If it does not, it will return [DeltaTableError::InvalidData] with a list /// of values that violated each invariant. 
pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { + self.check_nullability(record_batch)?; self.enforce_checks(record_batch, &self.invariants).await?; self.enforce_checks(record_batch, &self.constraints).await } + /// Return true if all the nullability checks are valid + fn check_nullability(&self, record_batch: &RecordBatch) -> Result { + let mut violations = Vec::new(); + for col in self.non_nullable_columns.iter() { + if let Some(arr) = record_batch.column_by_name(col) { + if arr.null_count() > 0 { + violations.push(format!( + "Non-nullable column violation for {col}, found {} null values", + arr.null_count() + )); + } + } else { + violations.push(format!( + "Non-nullable column violation for {col}, not found in batch!" + )); + } + } + if !violations.is_empty() { + Err(DeltaTableError::InvalidData { violations }) + } else { + Ok(true) + } + } + async fn enforce_checks( &self, record_batch: &RecordBatch, @@ -2598,4 +2642,38 @@ mod tests { assert_eq!(actual.len(), 0); } + + #[tokio::test] + async fn test_check_nullability() -> DeltaResult<()> { + use arrow::array::StringArray; + + let data_checker = DeltaDataChecker { + non_nullable_columns: vec!["zed".to_string(), "yap".to_string()], + ..Default::default() + }; + + let arr: Arc = Arc::new(StringArray::from(vec!["s"])); + let nulls: Arc = Arc::new(StringArray::new_null(1)); + let batch = RecordBatch::try_from_iter(vec![("a", arr), ("zed", nulls)]).unwrap(); + + let result = data_checker.check_nullability(&batch); + assert!( + result.is_err(), + "The result should have errored! {result:?}" + ); + + let arr: Arc = Arc::new(StringArray::from(vec!["s"])); + let batch = RecordBatch::try_from_iter(vec![("zed", arr)]).unwrap(); + let result = data_checker.check_nullability(&batch); + assert!( + result.is_err(), + "The result should have errored! {result:?}" + ); + + let arr: Arc = Arc::new(StringArray::from(vec!["s"])); + let batch = RecordBatch::try_from_iter(vec![("zed", arr.clone()), ("yap", arr)]).unwrap(); + let _ = data_checker.check_nullability(&batch)?; + + Ok(()) + } } diff --git a/crates/core/src/kernel/snapshot/log_data.rs b/crates/core/src/kernel/snapshot/log_data.rs index 562ca3da90..f22f88ad23 100644 --- a/crates/core/src/kernel/snapshot/log_data.rs +++ b/crates/core/src/kernel/snapshot/log_data.rs @@ -245,23 +245,14 @@ impl LogicalFile<'_> { /// Defines a deletion vector pub fn deletion_vector(&self) -> Option> { - if let Some(arr) = self.deletion_vector.as_ref() { - // With v0.22 and the upgrade to a more recent arrow. Reading nullable structs with - // non-nullable entries back out of parquet is resulting in the DeletionVector having - // an empty string rather than a null. The addition check on the value ensures that a - // [DeletionVectorView] is not created in this scenario - // - // - if arr.storage_type.is_valid(self.index) - && !arr.storage_type.value(self.index).is_empty() - { - return Some(DeletionVectorView { + self.deletion_vector.as_ref().and_then(|arr| { + arr.storage_type + .is_valid(self.index) + .then_some(DeletionVectorView { data: arr, index: self.index, - }); - } - } - None + }) + }) } /// The number of records stored in the data file. 
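As a side note on the `DeltaDataChecker::check_nullability` logic added above: the rule is simply that every column declared non-nullable in the table schema must be present in the batch and contain zero nulls. Below is a minimal standalone sketch of that idea in Python using pyarrow; the helper name and column names are illustrative and not part of this change.

```python
# Standalone sketch of the nullability rule enforced above: a non-nullable
# column must exist in the batch and have a null_count of zero.
import pyarrow as pa


def check_nullability(batch: pa.RecordBatch, non_nullable_columns: list) -> None:
    violations = []
    for col in non_nullable_columns:
        idx = batch.schema.get_field_index(col)
        if idx == -1:
            violations.append(f"Non-nullable column violation for {col}, not found in batch!")
        elif batch.column(idx).null_count > 0:
            violations.append(
                f"Non-nullable column violation for {col}, "
                f"found {batch.column(idx).null_count} null values"
            )
    if violations:
        raise ValueError(f"Invariant violations: {violations}")


batch = pa.RecordBatch.from_pydict({"id": [1, 2], "value": ["a", None]})
check_nullability(batch, ["id"])             # passes
# check_nullability(batch, ["id", "value"])  # raises: "value" contains 1 null
```

In the Rust code the same condition produces `DeltaTableError::InvalidData`, which surfaces in Python as a `DeltaProtocolError`; the `test_merge_non_nullable` test added further below exercises exactly that path.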
@@ -380,18 +371,23 @@ impl<'a> FileStatsAccessor<'a> { ); let deletion_vector = extract_and_cast_opt::(data, "add.deletionVector"); let deletion_vector = deletion_vector.and_then(|dv| { - let storage_type = extract_and_cast::(dv, "storageType").ok()?; - let path_or_inline_dv = extract_and_cast::(dv, "pathOrInlineDv").ok()?; - let size_in_bytes = extract_and_cast::(dv, "sizeInBytes").ok()?; - let cardinality = extract_and_cast::(dv, "cardinality").ok()?; - let offset = extract_and_cast_opt::(dv, "offset"); - Some(DeletionVector { - storage_type, - path_or_inline_dv, - size_in_bytes, - cardinality, - offset, - }) + if dv.null_count() == dv.len() { + None + } else { + let storage_type = extract_and_cast::(dv, "storageType").ok()?; + let path_or_inline_dv = + extract_and_cast::(dv, "pathOrInlineDv").ok()?; + let size_in_bytes = extract_and_cast::(dv, "sizeInBytes").ok()?; + let cardinality = extract_and_cast::(dv, "cardinality").ok()?; + let offset = extract_and_cast_opt::(dv, "offset"); + Some(DeletionVector { + storage_type, + path_or_inline_dv, + size_in_bytes, + cardinality, + offset, + }) + } }); Ok(Self { diff --git a/crates/core/src/kernel/snapshot/parse.rs b/crates/core/src/kernel/snapshot/parse.rs index f75744691e..e8630cbe0c 100644 --- a/crates/core/src/kernel/snapshot/parse.rs +++ b/crates/core/src/kernel/snapshot/parse.rs @@ -78,6 +78,10 @@ pub(super) fn read_adds(array: &dyn ProvidesColumnByName) -> DeltaResult(array, "add") { + // Stop early if all values are null + if arr.null_count() == arr.len() { + return Ok(vec![]); + } let path = ex::extract_and_cast::(arr, "path")?; let pvs = ex::extract_and_cast_opt::(arr, "partitionValues"); let size = ex::extract_and_cast::(arr, "size")?; @@ -94,22 +98,33 @@ pub(super) fn read_adds(array: &dyn ProvidesColumnByName) -> DeltaResult(d, "sizeInBytes")?; let cardinality = ex::extract_and_cast::(d, "cardinality")?; - Box::new(|idx: usize| { - if ex::read_str(storage_type, idx).is_ok() { - Some(DeletionVectorDescriptor { - storage_type: std::str::FromStr::from_str( - ex::read_str(storage_type, idx).ok()?, - ) - .ok()?, - path_or_inline_dv: ex::read_str(path_or_inline_dv, idx).ok()?.to_string(), - offset: ex::read_primitive_opt(offset, idx), - size_in_bytes: ex::read_primitive(size_in_bytes, idx).ok()?, - cardinality: ex::read_primitive(cardinality, idx).ok()?, - }) - } else { - None - } - }) + // Column might exist but have nullability set for the whole array, so we just return Nones + if d.null_count() == d.len() { + Box::new(|_| None) + } else { + Box::new(|idx: usize| { + d.is_valid(idx) + .then(|| { + if ex::read_str(storage_type, idx).is_ok() { + Some(DeletionVectorDescriptor { + storage_type: std::str::FromStr::from_str( + ex::read_str(storage_type, idx).ok()?, + ) + .ok()?, + path_or_inline_dv: ex::read_str(path_or_inline_dv, idx) + .ok()? 
+ .to_string(), + offset: ex::read_primitive_opt(offset, idx), + size_in_bytes: ex::read_primitive(size_in_bytes, idx).ok()?, + cardinality: ex::read_primitive(cardinality, idx).ok()?, + }) + } else { + None + } + }) + .flatten() + }) + } } else { Box::new(|_| None) }; @@ -210,22 +225,33 @@ pub(super) fn read_removes(array: &dyn ProvidesColumnByName) -> DeltaResult(d, "sizeInBytes")?; let cardinality = ex::extract_and_cast::(d, "cardinality")?; - Box::new(|idx: usize| { - if ex::read_str(storage_type, idx).is_ok() { - Some(DeletionVectorDescriptor { - storage_type: std::str::FromStr::from_str( - ex::read_str(storage_type, idx).ok()?, - ) - .ok()?, - path_or_inline_dv: ex::read_str(path_or_inline_dv, idx).ok()?.to_string(), - offset: ex::read_primitive_opt(offset, idx), - size_in_bytes: ex::read_primitive(size_in_bytes, idx).ok()?, - cardinality: ex::read_primitive(cardinality, idx).ok()?, - }) - } else { - None - } - }) + // Column might exist but have nullability set for the whole array, so we just return Nones + if d.null_count() == d.len() { + Box::new(|_| None) + } else { + Box::new(|idx: usize| { + d.is_valid(idx) + .then(|| { + if ex::read_str(storage_type, idx).is_ok() { + Some(DeletionVectorDescriptor { + storage_type: std::str::FromStr::from_str( + ex::read_str(storage_type, idx).ok()?, + ) + .ok()?, + path_or_inline_dv: ex::read_str(path_or_inline_dv, idx) + .ok()? + .to_string(), + offset: ex::read_primitive_opt(offset, idx), + size_in_bytes: ex::read_primitive(size_in_bytes, idx).ok()?, + cardinality: ex::read_primitive(cardinality, idx).ok()?, + }) + } else { + None + } + }) + .flatten() + }) + } } else { Box::new(|_| None) }; diff --git a/crates/core/src/kernel/snapshot/replay.rs b/crates/core/src/kernel/snapshot/replay.rs index 1b18b61bc7..6267a7f3be 100644 --- a/crates/core/src/kernel/snapshot/replay.rs +++ b/crates/core/src/kernel/snapshot/replay.rs @@ -20,7 +20,7 @@ use hashbrown::HashSet; use itertools::Itertools; use percent_encoding::percent_decode_str; use pin_project_lite::pin_project; -use tracing::debug; +use tracing::log::*; use super::parse::collect_map; use super::ReplayVisitor; @@ -440,6 +440,14 @@ pub(super) struct DVInfo<'a> { fn seen_key(info: &FileInfo<'_>) -> String { let path = percent_decode_str(info.path).decode_utf8_lossy(); if let Some(dv) = &info.dv { + // If storage_type is empty then delta-rs has somehow gotten an empty rather than a null + // deletion vector, oooof + // + // See #3030 + if dv.storage_type.is_empty() { + warn!("An empty but not nullable deletionVector was seen for {info:?}"); + return path.to_string(); + } if let Some(offset) = &dv.offset { format!( "{}::{}{}@{offset}", @@ -551,22 +559,32 @@ fn read_file_info<'a>(arr: &'a dyn ProvidesColumnByName) -> DeltaResult(d, "pathOrInlineDv")?; let offset = ex::extract_and_cast::(d, "offset")?; - Box::new(|idx: usize| { - if ex::read_str(storage_type, idx).is_ok() { - Ok(Some(DVInfo { - storage_type: ex::read_str(storage_type, idx)?, - path_or_inline_dv: ex::read_str(path_or_inline_dv, idx)?, - offset: ex::read_primitive_opt(offset, idx), - })) - } else { - Ok(None) - } - }) + // Column might exist but have nullability set for the whole array, so we just return Nones + if d.null_count() == d.len() { + Box::new(|_| Ok(None)) + } else { + Box::new(|idx: usize| { + if d.is_valid(idx) { + if ex::read_str(storage_type, idx).is_ok() { + Ok(Some(DVInfo { + storage_type: ex::read_str(storage_type, idx)?, + path_or_inline_dv: ex::read_str(path_or_inline_dv, idx)?, + offset: 
ex::read_primitive_opt(offset, idx), + })) + } else { + Ok(None) + } + } else { + Ok(None) + } + }) + } } else { Box::new(|_| Ok(None)) }; let mut adds = Vec::with_capacity(path.len()); + for idx in 0..path.len() { let value = path .is_valid(idx) @@ -579,6 +597,7 @@ fn read_file_info<'a>(arr: &'a dyn ProvidesColumnByName) -> DeltaResult + #[cfg(feature = "datafusion")] #[tokio::test] async fn test_create_checkpoint_overwrite() -> DeltaResult<()> { use crate::protocol::SaveMode; + use crate::writer::test_utils::datafusion::get_data_sorted; use crate::writer::test_utils::get_arrow_schema; + use datafusion::assert_batches_sorted_eq; + + let tmp_dir = tempfile::tempdir().unwrap(); + let tmp_path = std::fs::canonicalize(tmp_dir.path()).unwrap(); let batch = RecordBatch::try_new( Arc::clone(&get_arrow_schema(&None)), @@ -1177,13 +1183,15 @@ mod tests { ], ) .unwrap(); - let table = DeltaOps::try_from_uri_with_storage_options("memory://", HashMap::default()) + + let mut table = DeltaOps::try_from_uri(tmp_path.as_os_str().to_str().unwrap()) .await? .write(vec![batch]) .await?; + table.load().await?; assert_eq!(table.version(), 0); - create_checkpoint_for(0, table.snapshot().unwrap(), table.log_store.as_ref()).await?; + create_checkpoint(&table).await?; let batch = RecordBatch::try_new( Arc::clone(&get_arrow_schema(&None)), @@ -1194,11 +1202,23 @@ mod tests { ], ) .unwrap(); - let table = DeltaOps(table) + + let table = DeltaOps::try_from_uri(tmp_path.as_os_str().to_str().unwrap()) + .await? .write(vec![batch]) .with_save_mode(SaveMode::Overwrite) .await?; assert_eq!(table.version(), 1); + + let expected = [ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 0 | 2021-02-02 |", + "+----+-------+------------+", + ]; + let actual = get_data_sorted(&table, "id,value,modified").await; + assert_batches_sorted_eq!(&expected, &actual); Ok(()) } } diff --git a/crates/deltalake/Cargo.toml b/crates/deltalake/Cargo.toml index 3c5a13172e..476f0b5d60 100644 --- a/crates/deltalake/Cargo.toml +++ b/crates/deltalake/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake" -version = "0.22.1" +version = "0.22.3" authors.workspace = true keywords.workspace = true readme.workspace = true @@ -16,7 +16,7 @@ rust-version.workspace = true features = ["azure", "datafusion", "gcs", "hdfs", "json", "python", "s3", "unity-experimental"] [dependencies] -deltalake-core = { version = "0.22.0", path = "../core" } +deltalake-core = { version = "0.22.2", path = "../core" } deltalake-aws = { version = "0.5.0", path = "../aws", default-features = false, optional = true } deltalake-azure = { version = "0.5.0", path = "../azure", optional = true } deltalake-gcp = { version = "0.6.0", path = "../gcp", optional = true } diff --git a/python/Cargo.toml b/python/Cargo.toml index c89f68b8c0..bb6fbba621 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.22.1" +version = "0.22.3" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" diff --git a/python/deltalake/__init__.py b/python/deltalake/__init__.py index 43997076b2..2e82d20a40 100644 --- a/python/deltalake/__init__.py +++ b/python/deltalake/__init__.py @@ -2,6 +2,7 @@ from ._internal import __version__ as __version__ from ._internal import rust_core_version as rust_core_version from .data_catalog import DataCatalog as DataCatalog +from .query import QueryBuilder as QueryBuilder from .schema import 
DataType as DataType
 from .schema import Field as Field
 from .schema import Schema as Schema
diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index 026d84d08d..052cf1ebb6 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -874,6 +874,11 @@ class DeltaFileSystemHandler:
     ) -> ObjectOutputStream:
         """Open an output stream for sequential writing."""
 
+class PyQueryBuilder:
+    def __init__(self) -> None: ...
+    def register(self, table_name: str, delta_table: RawDeltaTable) -> None: ...
+    def execute(self, sql: str) -> List[pyarrow.RecordBatch]: ...
+
 class DeltaDataChecker:
     def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
     def check_batch(self, batch: pyarrow.RecordBatch) -> None: ...
diff --git a/python/deltalake/query.py b/python/deltalake/query.py
new file mode 100644
index 0000000000..b3e4100d8c
--- /dev/null
+++ b/python/deltalake/query.py
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+import warnings
+from typing import List
+
+import pyarrow
+
+from deltalake._internal import PyQueryBuilder
+from deltalake.table import DeltaTable
+from deltalake.warnings import ExperimentalWarning
+
+
+class QueryBuilder:
+    """
+    QueryBuilder is an experimental API which exposes Apache DataFusion SQL to Python users of the deltalake library.
+
+    This API is subject to change.
+
+    >>> qb = QueryBuilder()
+    """
+
+    def __init__(self) -> None:
+        warnings.warn(
+            "QueryBuilder is experimental and subject to change",
+            category=ExperimentalWarning,
+        )
+        self._query_builder = PyQueryBuilder()
+
+    def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
+        """
+        Add a table to the query builder instance by name. The `table_name`
+        is how the provided `DeltaTable` can be referenced in SQL queries.
+ + For example: + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> assert qb is not None + """ + self._query_builder.register( + table_name=table_name, + delta_table=delta_table._table, + ) + return self + + def execute(self, sql: str) -> List[pyarrow.RecordBatch]: + """ + Execute the query and return a list of record batches + + For example: + + >>> tmp = getfixture('tmp_path') + >>> import pyarrow as pa + >>> from deltalake import DeltaTable, QueryBuilder + >>> dt = DeltaTable.create(table_uri=tmp, schema=pa.schema([pa.field('name', pa.string())])) + >>> qb = QueryBuilder().register('test', dt) + >>> results = qb.execute('SELECT * FROM test') + >>> assert results is not None + """ + return self._query_builder.execute(sql) diff --git a/python/deltalake/warnings.py b/python/deltalake/warnings.py new file mode 100644 index 0000000000..83c5d34bcd --- /dev/null +++ b/python/deltalake/warnings.py @@ -0,0 +1,2 @@ +class ExperimentalWarning(Warning): + pass diff --git a/python/src/error.rs b/python/src/error.rs index a54b1e60b4..b1d22fc7ca 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -1,4 +1,5 @@ use arrow_schema::ArrowError; +use deltalake::datafusion::error::DataFusionError; use deltalake::protocol::ProtocolError; use deltalake::{errors::DeltaTableError, ObjectStoreError}; use pyo3::exceptions::{ @@ -79,6 +80,10 @@ fn checkpoint_to_py(err: ProtocolError) -> PyErr { } } +fn datafusion_to_py(err: DataFusionError) -> PyErr { + DeltaError::new_err(err.to_string()) +} + #[derive(thiserror::Error, Debug)] pub enum PythonError { #[error("Error in delta table")] @@ -89,6 +94,8 @@ pub enum PythonError { Arrow(#[from] ArrowError), #[error("Error in checkpoint")] Protocol(#[from] ProtocolError), + #[error("Error in data fusion")] + DataFusion(#[from] DataFusionError), } impl From for pyo3::PyErr { @@ -98,6 +105,7 @@ impl From for pyo3::PyErr { PythonError::ObjectStore(err) => object_store_to_py(err), PythonError::Arrow(err) => arrow_to_py(err), PythonError::Protocol(err) => checkpoint_to_py(err), + PythonError::DataFusion(err) => datafusion_to_py(err), } } } diff --git a/python/src/filesystem.rs b/python/src/filesystem.rs index 116b1b0cf1..ee5261ab09 100644 --- a/python/src/filesystem.rs +++ b/python/src/filesystem.rs @@ -13,7 +13,7 @@ use std::sync::Arc; const DEFAULT_MAX_BUFFER_SIZE: usize = 5 * 1024 * 1024; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub(crate) struct FsConfig { pub(crate) root_url: String, pub(crate) options: HashMap, diff --git a/python/src/lib.rs b/python/src/lib.rs index 8da2544b3f..c4a4d80b78 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -2,6 +2,7 @@ mod error; mod features; mod filesystem; mod merge; +mod query; mod schema; mod utils; @@ -20,12 +21,18 @@ use delta_kernel::expressions::Scalar; use delta_kernel::schema::StructField; use deltalake::arrow::compute::concat_batches; use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; +use deltalake::arrow::pyarrow::ToPyArrow; use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator}; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::{cleanup_metadata, create_checkpoint}; +use deltalake::datafusion::datasource::provider_as_source; +use 
deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use deltalake::datafusion::physical_plan::ExecutionPlan;
-use deltalake::datafusion::prelude::SessionContext;
-use deltalake::delta_datafusion::DeltaDataChecker;
+use deltalake::datafusion::prelude::{DataFrame, SessionContext};
+use deltalake::delta_datafusion::{
+    DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
+    DeltaTableProvider,
+};
 use deltalake::errors::DeltaTableError;
 use deltalake::kernel::{
     scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType, Transaction,
@@ -69,6 +76,7 @@ use crate::error::PythonError;
 use crate::features::TableFeatures;
 use crate::filesystem::FsConfig;
 use crate::merge::PyMergeBuilder;
+use crate::query::PyQueryBuilder;
 use crate::schema::{schema_to_pyobject, Field};
 use crate::utils::rt;
@@ -2095,6 +2103,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
     )?)?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<PyQueryBuilder>()?;
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
diff --git a/python/src/query.rs b/python/src/query.rs
new file mode 100644
index 0000000000..55889c567f
--- /dev/null
+++ b/python/src/query.rs
@@ -0,0 +1,73 @@
+use std::sync::Arc;
+
+use deltalake::{
+    arrow::pyarrow::ToPyArrow,
+    datafusion::prelude::SessionContext,
+    delta_datafusion::{DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider},
+};
+use pyo3::prelude::*;
+
+use crate::{error::PythonError, utils::rt, RawDeltaTable};
+
+/// PyQueryBuilder supports the _experimental_ `QueryBuilder` Python interface which allows users
+/// to take advantage of the [Apache DataFusion](https://datafusion.apache.org) engine already
+/// present in the Python package.
+#[pyclass(module = "deltalake._internal")]
+#[derive(Default)]
+pub(crate) struct PyQueryBuilder {
+    /// DataFusion [SessionContext] to hold mappings of registered tables
+    ctx: SessionContext,
+}
+
+#[pymethods]
+impl PyQueryBuilder {
+    #[new]
+    pub fn new() -> Self {
+        let config = DeltaSessionConfig::default().into();
+        let ctx = SessionContext::new_with_config(config);
+
+        PyQueryBuilder { ctx }
+    }
+
+    /// Register the given [RawDeltaTable] into the [SessionContext] using the provided
+    /// `table_name`
+    ///
+    /// Once called, the provided `delta_table` will be referenceable in SQL queries so long as
+    /// another table of the same name is not registered over it.
+    pub fn register(&self, table_name: &str, delta_table: &RawDeltaTable) -> PyResult<()> {
+        let snapshot = delta_table._table.snapshot().map_err(PythonError::from)?;
+        let log_store = delta_table._table.log_store();
+
+        let scan_config = DeltaScanConfigBuilder::default()
+            .build(snapshot)
+            .map_err(PythonError::from)?;
+
+        let provider = Arc::new(
+            DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
+                .map_err(PythonError::from)?,
+        );
+
+        self.ctx
+            .register_table(table_name, provider)
+            .map_err(PythonError::from)?;
+
+        Ok(())
+    }
+
+    /// Execute the given SQL command within the [SessionContext] of this instance
+    ///
+    /// **NOTE:** Since this function returns a materialized Python list of `RecordBatch`
+    /// instances, it may result in unexpected memory consumption for queries which return
+    /// large data sets.
+ pub fn execute(&self, py: Python, sql: &str) -> PyResult { + let batches = py.allow_threads(|| { + rt().block_on(async { + let df = self.ctx.sql(sql).await?; + df.collect().await + }) + .map_err(PythonError::from) + })?; + + batches.to_pyarrow(py) + } +} diff --git a/python/tests/test_checkpoint.py b/python/tests/test_checkpoint.py index 5961a57b09..309a1f3663 100644 --- a/python/tests/test_checkpoint.py +++ b/python/tests/test_checkpoint.py @@ -483,18 +483,19 @@ def test_checkpoint_with_multiple_writes(tmp_path: pathlib.Path): } ), ) - DeltaTable(tmp_path).create_checkpoint() + dt = DeltaTable(tmp_path) + dt.create_checkpoint() + assert dt.version() == 0 + df = pd.DataFrame( + { + "a": ["a"], + "b": [100], + } + ) + write_deltalake(tmp_path, df, mode="overwrite") dt = DeltaTable(tmp_path) + assert dt.version() == 1 + new_df = dt.to_pandas() print(dt.to_pandas()) - - write_deltalake( - tmp_path, - pd.DataFrame( - { - "a": ["a"], - "b": [100], - } - ), - mode="overwrite", - ) + assert len(new_df) == 1, "We overwrote! there should only be one row" diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 8645be538b..e8416f6e5f 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -6,6 +6,7 @@ import pytest from deltalake import DeltaTable, write_deltalake +from deltalake.exceptions import DeltaProtocolError from deltalake.table import CommitProperties @@ -1080,3 +1081,42 @@ def test_cdc_merge_planning_union_2908(tmp_path): assert last_action["operation"] == "MERGE" assert dt.version() == 1 assert os.path.exists(cdc_path), "_change_data doesn't exist" + + +@pytest.mark.pandas +def test_merge_non_nullable(tmp_path): + import re + + import pandas as pd + + from deltalake.schema import Field, PrimitiveType, Schema + + schema = Schema( + [ + Field("id", PrimitiveType("integer"), nullable=False), + Field("bool", PrimitiveType("boolean"), nullable=False), + ] + ) + + dt = DeltaTable.create(tmp_path, schema=schema) + df = pd.DataFrame( + columns=["id", "bool"], + data=[ + [1, True], + [2, None], + [3, False], + ], + ) + + with pytest.raises( + DeltaProtocolError, + match=re.escape( + 'Invariant violations: ["Non-nullable column violation for bool, found 1 null values"]' + ), + ): + dt.merge( + source=df, + source_alias="s", + target_alias="t", + predicate="s.id = t.id", + ).when_matched_update_all().when_not_matched_insert_all().execute() diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index 5ff07ed9e8..30d7f21d7f 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -9,6 +9,7 @@ from deltalake._util import encode_partition_value from deltalake.exceptions import DeltaProtocolError +from deltalake.query import QueryBuilder from deltalake.table import ProtocolVersions from deltalake.writer import write_deltalake @@ -946,3 +947,56 @@ def test_is_deltatable_with_storage_opts(): "DELTA_DYNAMO_TABLE_NAME": "custom_table_name", } assert DeltaTable.is_deltatable(table_path, storage_options=storage_options) + + +def test_read_query_builder(): + table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" + dt = DeltaTable(table_path) + expected = { + "value": ["4", "5", "6", "7"], + "year": ["2021", "2021", "2021", "2021"], + "month": ["4", "12", "12", "12"], + "day": ["5", "4", "20", "20"], + } + actual = pa.Table.from_batches( + QueryBuilder() + .register("tbl", dt) + .execute("SELECT * FROM tbl WHERE year >= 2021 ORDER BY value") + ).to_pydict() + assert expected == actual + + +def 
test_read_query_builder_join_multiple_tables(tmp_path): + table_path = "../crates/test/tests/data/delta-0.8.0-date" + dt1 = DeltaTable(table_path) + + write_deltalake( + tmp_path, + pa.table( + { + "date": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-12-31"], + "value": ["a", "b", "c", "d"], + } + ), + ) + dt2 = DeltaTable(tmp_path) + + expected = { + "date": ["2021-01-01", "2021-01-02", "2021-01-03"], + "dayOfYear": [1, 2, 3], + "value": ["a", "b", "c"], + } + actual = pa.Table.from_batches( + QueryBuilder() + .register("tbl1", dt1) + .register("tbl2", dt2) + .execute( + """ + SELECT tbl2.date, tbl1.dayOfYear, tbl2.value + FROM tbl1 + INNER JOIN tbl2 ON tbl1.date = tbl2.date + ORDER BY tbl1.date + """ + ) + ).to_pydict() + assert expected == actual
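Taken end to end, the new Python surface added in this change can be exercised roughly as below. This is a usage sketch, not part of the diff: `QueryBuilder`, its `register`/`execute` methods, and `ExperimentalWarning` come from the files added above, while the temporary table location, schema, and data are purely illustrative.

```python
# Usage sketch for the experimental QueryBuilder API introduced in this change.
import tempfile
import warnings

import pyarrow as pa
from deltalake import DeltaTable, QueryBuilder, write_deltalake
from deltalake.warnings import ExperimentalWarning

# Constructing a QueryBuilder emits an ExperimentalWarning; opt out explicitly if desired.
warnings.simplefilter("ignore", ExperimentalWarning)

table_uri = tempfile.mkdtemp()
write_deltalake(table_uri, pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]}))
dt = DeltaTable(table_uri)

batches = QueryBuilder().register("demo", dt).execute(
    "SELECT value FROM demo WHERE id > 1 ORDER BY value"
)
print(pa.Table.from_batches(batches).to_pydict())  # {'value': ['b', 'c']}
```

Because `execute` materializes every `RecordBatch` into a Python list (see the note on `PyQueryBuilder::execute` above), large result sets should be narrowed with a `WHERE` clause or projection before collection.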