diff --git a/Cargo.toml b/Cargo.toml index 6168a500fd..4cc6bad1b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,9 @@ debug = true debug = "line-tables-only" [workspace.dependencies] +delta_kernel = { version = "0.1" } +# delta_kernel = { path = "../delta-kernel-rs/kernel" } + # arrow arrow = { version = "51" } arrow-arith = { version = "51" } diff --git a/crates/aws/tests/integration_s3_dynamodb.rs b/crates/aws/tests/integration_s3_dynamodb.rs index eb674c4235..6e030e7bb2 100644 --- a/crates/aws/tests/integration_s3_dynamodb.rs +++ b/crates/aws/tests/integration_s3_dynamodb.rs @@ -390,7 +390,7 @@ async fn prepare_table(context: &IntegrationContext, table_name: &str) -> TestRe // create delta table let table = DeltaOps(table) .create() - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await?; println!("table created: {table:?}"); Ok(table) diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index a83be3ce37..680d1c1475 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -15,6 +15,8 @@ rust-version.workspace = true features = ["datafusion", "json", "unity-experimental"] [dependencies] +delta_kernel.workspace = true + # arrow arrow = { workspace = true } arrow-arith = { workspace = true } diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index 868969c571..41e6a84b4f 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -542,7 +542,7 @@ mod test { let table = DeltaOps::new_in_memory() .create() - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index fae36d7cbf..9c87411973 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -36,7 +36,6 @@ use arrow::record_batch::RecordBatch; use arrow_array::types::UInt16Type; use arrow_array::{Array, DictionaryArray, StringArray, TypedDictionaryArray}; use arrow_cast::display::array_value_to_string; - use arrow_schema::Field; use async_trait::async_trait; use chrono::{DateTime, TimeZone, Utc}; @@ -78,7 +77,6 @@ use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_sql::planner::ParserOptions; use either::Either; use futures::TryStreamExt; - use itertools::Itertools; use object_store::ObjectMeta; use serde::{Deserialize, Serialize}; @@ -86,7 +84,7 @@ use url::Url; use crate::delta_datafusion::expr::parse_predicate_expression; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot}; +use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTypeExt}; use crate::logstore::LogStoreRef; use crate::table::builder::ensure_table_uri; use crate::table::state::DeltaTableState; @@ -202,13 +200,11 @@ fn _arrow_schema(snapshot: &Snapshot, wrap_partitions: bool) -> DeltaResult = Result; #[allow(missing_docs)] #[derive(thiserror::Error, Debug)] pub enum DeltaTableError { + #[error("Kernel error: {0}")] + KernelError(#[from] delta_kernel::error::Error), + #[error("Delta protocol violation: {source}")] Protocol { source: ProtocolError }, diff --git a/crates/core/src/kernel/arrow/mod.rs b/crates/core/src/kernel/arrow/mod.rs index 648ad16bbc..ef30002af9 100644 --- a/crates/core/src/kernel/arrow/mod.rs +++ b/crates/core/src/kernel/arrow/mod.rs @@ -3,16 +3,11 @@ use std::sync::Arc; use 
arrow_schema::{ - ArrowError, DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef, - Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, TimeUnit, + DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef, + Schema as ArrowSchema, SchemaRef as ArrowSchemaRef, }; use lazy_static::lazy_static; -use super::{ - ActionType, ArrayType, DataType, MapType, PrimitiveType, StructField, StructType, - DECIMAL_MAX_PRECISION, DECIMAL_MAX_SCALE, -}; - pub(crate) mod extract; pub(crate) mod json; @@ -21,258 +16,6 @@ const MAP_KEY_DEFAULT: &str = "key"; const MAP_VALUE_DEFAULT: &str = "value"; const LIST_ROOT_DEFAULT: &str = "item"; -impl TryFrom for ArrowField { - type Error = ArrowError; - - fn try_from(value: ActionType) -> Result { - value.schema_field().try_into() - } -} - -impl TryFrom<&StructType> for ArrowSchema { - type Error = ArrowError; - - fn try_from(s: &StructType) -> Result { - let fields = s - .fields() - .iter() - .map(TryInto::try_into) - .collect::, ArrowError>>()?; - - Ok(ArrowSchema::new(fields)) - } -} - -impl TryFrom<&StructField> for ArrowField { - type Error = ArrowError; - - fn try_from(f: &StructField) -> Result { - let metadata = f - .metadata() - .iter() - .map(|(key, val)| Ok((key.clone(), serde_json::to_string(val)?))) - .collect::>() - .map_err(|err| ArrowError::JsonError(err.to_string()))?; - - let field = ArrowField::new( - f.name(), - ArrowDataType::try_from(f.data_type())?, - f.is_nullable(), - ) - .with_metadata(metadata); - - Ok(field) - } -} - -impl TryFrom<&ArrayType> for ArrowField { - type Error = ArrowError; - fn try_from(a: &ArrayType) -> Result { - Ok(ArrowField::new( - LIST_ROOT_DEFAULT, - ArrowDataType::try_from(a.element_type())?, - // TODO check how to handle nullability - a.contains_null(), - )) - } -} - -impl TryFrom<&MapType> for ArrowField { - type Error = ArrowError; - - fn try_from(a: &MapType) -> Result { - Ok(ArrowField::new( - MAP_ROOT_DEFAULT, - ArrowDataType::Struct( - vec![ - ArrowField::new( - MAP_KEY_DEFAULT, - ArrowDataType::try_from(a.key_type())?, - false, - ), - ArrowField::new( - MAP_VALUE_DEFAULT, - ArrowDataType::try_from(a.value_type())?, - a.value_contains_null(), - ), - ] - .into(), - ), - // always non-null - false, - )) - } -} - -impl TryFrom<&DataType> for ArrowDataType { - type Error = ArrowError; - - fn try_from(t: &DataType) -> Result { - match t { - DataType::Primitive(p) => { - match p { - PrimitiveType::String => Ok(ArrowDataType::Utf8), - PrimitiveType::Long => Ok(ArrowDataType::Int64), // undocumented type - PrimitiveType::Integer => Ok(ArrowDataType::Int32), - PrimitiveType::Short => Ok(ArrowDataType::Int16), - PrimitiveType::Byte => Ok(ArrowDataType::Int8), - PrimitiveType::Float => Ok(ArrowDataType::Float32), - PrimitiveType::Double => Ok(ArrowDataType::Float64), - PrimitiveType::Boolean => Ok(ArrowDataType::Boolean), - PrimitiveType::Binary => Ok(ArrowDataType::Binary), - PrimitiveType::Decimal(precision, scale) => { - if precision <= &DECIMAL_MAX_PRECISION && scale <= &DECIMAL_MAX_SCALE { - Ok(ArrowDataType::Decimal128(*precision, *scale)) - } else { - Err(ArrowError::CastError(format!( - "Precision/scale can not be larger than 38 ({},{})", - precision, scale - ))) - } - } - PrimitiveType::Date => { - // A calendar date, represented as a year-month-day triple without a - // timezone. 
Stored as 4 bytes integer representing days since 1970-01-01 - Ok(ArrowDataType::Date32) - } - PrimitiveType::Timestamp => Ok(ArrowDataType::Timestamp( - TimeUnit::Microsecond, - Some("UTC".into()), - )), - PrimitiveType::TimestampNtz => { - Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)) - } - } - } - DataType::Struct(s) => Ok(ArrowDataType::Struct( - s.fields() - .iter() - .map(TryInto::try_into) - .collect::, ArrowError>>()? - .into(), - )), - DataType::Array(a) => Ok(ArrowDataType::List(Arc::new(a.as_ref().try_into()?))), - DataType::Map(m) => Ok(ArrowDataType::Map(Arc::new(m.as_ref().try_into()?), false)), - } - } -} - -impl TryFrom<&ArrowSchema> for StructType { - type Error = ArrowError; - - fn try_from(arrow_schema: &ArrowSchema) -> Result { - let new_fields: Result, _> = arrow_schema - .fields() - .iter() - .map(|field| field.as_ref().try_into()) - .collect(); - Ok(StructType::new(new_fields?)) - } -} - -impl TryFrom for StructType { - type Error = ArrowError; - - fn try_from(arrow_schema: ArrowSchemaRef) -> Result { - arrow_schema.as_ref().try_into() - } -} - -impl TryFrom<&ArrowField> for StructField { - type Error = ArrowError; - - fn try_from(arrow_field: &ArrowField) -> Result { - Ok(StructField::new( - arrow_field.name().clone(), - DataType::try_from(arrow_field.data_type())?, - arrow_field.is_nullable(), - ) - .with_metadata(arrow_field.metadata().iter().map(|(k, v)| (k.clone(), v)))) - } -} - -impl TryFrom<&ArrowDataType> for DataType { - type Error = ArrowError; - - fn try_from(arrow_datatype: &ArrowDataType) -> Result { - match arrow_datatype { - ArrowDataType::Utf8 => Ok(DataType::Primitive(PrimitiveType::String)), - ArrowDataType::LargeUtf8 => Ok(DataType::Primitive(PrimitiveType::String)), - ArrowDataType::Int64 => Ok(DataType::Primitive(PrimitiveType::Long)), // undocumented type - ArrowDataType::Int32 => Ok(DataType::Primitive(PrimitiveType::Integer)), - ArrowDataType::Int16 => Ok(DataType::Primitive(PrimitiveType::Short)), - ArrowDataType::Int8 => Ok(DataType::Primitive(PrimitiveType::Byte)), - ArrowDataType::UInt64 => Ok(DataType::Primitive(PrimitiveType::Long)), // undocumented type - ArrowDataType::UInt32 => Ok(DataType::Primitive(PrimitiveType::Integer)), - ArrowDataType::UInt16 => Ok(DataType::Primitive(PrimitiveType::Short)), - ArrowDataType::UInt8 => Ok(DataType::Primitive(PrimitiveType::Byte)), - ArrowDataType::Float32 => Ok(DataType::Primitive(PrimitiveType::Float)), - ArrowDataType::Float64 => Ok(DataType::Primitive(PrimitiveType::Double)), - ArrowDataType::Boolean => Ok(DataType::Primitive(PrimitiveType::Boolean)), - ArrowDataType::Binary => Ok(DataType::Primitive(PrimitiveType::Binary)), - ArrowDataType::FixedSizeBinary(_) => Ok(DataType::Primitive(PrimitiveType::Binary)), - ArrowDataType::LargeBinary => Ok(DataType::Primitive(PrimitiveType::Binary)), - ArrowDataType::Decimal128(p, s) => { - Ok(DataType::Primitive(PrimitiveType::Decimal(*p, *s))) - } - ArrowDataType::Decimal256(p, s) => DataType::decimal(*p, *s).map_err(|_| { - ArrowError::SchemaError(format!( - "Invalid data type for Delta Lake: decimal({},{})", - p, s - )) - }), - ArrowDataType::Date32 => Ok(DataType::Primitive(PrimitiveType::Date)), - ArrowDataType::Date64 => Ok(DataType::Primitive(PrimitiveType::Date)), - ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { - Ok(DataType::Primitive(PrimitiveType::TimestampNtz)) - } - ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) - if tz.eq_ignore_ascii_case("utc") => - { - Ok(DataType::Primitive(PrimitiveType::Timestamp)) 
- } - ArrowDataType::Struct(fields) => { - let converted_fields: Result, _> = fields - .iter() - .map(|field| field.as_ref().try_into()) - .collect(); - Ok(DataType::Struct(Box::new(StructType::new( - converted_fields?, - )))) - } - ArrowDataType::List(field) => Ok(DataType::Array(Box::new(ArrayType::new( - (*field).data_type().try_into()?, - (*field).is_nullable(), - )))), - ArrowDataType::LargeList(field) => Ok(DataType::Array(Box::new(ArrayType::new( - (*field).data_type().try_into()?, - (*field).is_nullable(), - )))), - ArrowDataType::FixedSizeList(field, _) => Ok(DataType::Array(Box::new( - ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()), - ))), - ArrowDataType::Map(field, _) => { - if let ArrowDataType::Struct(struct_fields) = field.data_type() { - let key_type = struct_fields[0].data_type().try_into()?; - let value_type = struct_fields[1].data_type().try_into()?; - let value_type_nullable = struct_fields[1].is_nullable(); - Ok(DataType::Map(Box::new(MapType::new( - key_type, - value_type, - value_type_nullable, - )))) - } else { - panic!("DataType::Map should contain a struct field child"); - } - } - ArrowDataType::Dictionary(_, value_type) => Ok(value_type.as_ref().try_into()?), - s => Err(ArrowError::SchemaError(format!( - "Invalid data type for Delta Lake: {s}" - ))), - } - } -} - macro_rules! arrow_map { ($fieldname: ident, null) => { ArrowField::new( @@ -615,16 +358,14 @@ fn null_count_schema_for_fields(dest: &mut Vec, f: &ArrowField) { #[cfg(test)] mod tests { - use arrow::array::ArrayData; - use arrow_array::Array; - use arrow_array::{make_array, ArrayRef, MapArray, StringArray, StructArray}; - use arrow_buffer::{Buffer, ToByteSlice}; - use arrow_schema::Field; - - use super::*; use std::collections::HashMap; use std::sync::Arc; + use arrow_array::{MapArray, RecordBatch}; + use delta_kernel::schema::{DataType, MapType, PrimitiveType, StructField, StructType}; + + use super::*; + #[test] fn delta_log_schema_for_table_test() { // NOTE: We should future proof the checkpoint schema in case action schema changes. 
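The Delta-to-Arrow and Arrow-to-Delta `TryFrom` conversions deleted above are not lost: equivalent impls now come from the `delta_kernel` crate, so call sites keep using `try_into()`. A minimal sketch of the round-trip these conversions provide, assuming delta_kernel 0.1 exposes the same `TryFrom` impls and a `StructType::new(Vec<StructField>)` constructor (implied by the new imports in this diff, not shown by it):

```rust
// Sketch of the schema round-trip; all paths/signatures here are assumptions
// based on the imports introduced elsewhere in this diff.
use arrow_schema::Schema as ArrowSchema;
use delta_kernel::schema::{DataType, PrimitiveType, StructField, StructType};

fn schema_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
    let delta_schema = StructType::new(vec![
        StructField::new("id", DataType::Primitive(PrimitiveType::Long), false),
        StructField::new("ts", DataType::Primitive(PrimitiveType::Timestamp), true),
    ]);
    // Delta -> Arrow: Timestamp maps to Timestamp(Microsecond, Some("UTC")),
    // matching the impl removed above.
    let arrow_schema: ArrowSchema = (&delta_schema).try_into()?;
    // Arrow -> Delta recovers the same logical types.
    let roundtripped: StructType = (&arrow_schema).try_into()?;
    assert_eq!(delta_schema, roundtripped);
    Ok(())
}
```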
@@ -766,108 +507,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_arrow_from_delta_decimal_type() {
-        let precision = 20;
-        let scale = 2;
-        let decimal_field = DataType::Primitive(PrimitiveType::Decimal(precision, scale));
-        assert_eq!(
-            <ArrowDataType as TryFrom<&DataType>>::try_from(&decimal_field).unwrap(),
-            ArrowDataType::Decimal128(precision, scale)
-        );
-    }
-
-    #[test]
-    fn test_arrow_from_delta_decimal_type_invalid_precision() {
-        let precision = 39;
-        let scale = 2;
-        assert!(matches!(
-            <DataType as TryFrom<&ArrowDataType>>::try_from(&ArrowDataType::Decimal256(
-                precision, scale
-            ))
-            .unwrap_err(),
-            _
-        ));
-    }
-
-    #[test]
-    fn test_arrow_from_delta_decimal_type_invalid_scale() {
-        let precision = 2;
-        let scale = 39;
-        assert!(matches!(
-            <DataType as TryFrom<&ArrowDataType>>::try_from(&ArrowDataType::Decimal256(
-                precision, scale
-            ))
-            .unwrap_err(),
-            _
-        ));
-    }
-
-    #[test]
-    fn test_arrow_from_delta_timestamp_type() {
-        let timestamp_field = DataType::Primitive(PrimitiveType::Timestamp);
-        assert_eq!(
-            <ArrowDataType as TryFrom<&DataType>>::try_from(&timestamp_field).unwrap(),
-            ArrowDataType::Timestamp(TimeUnit::Microsecond, Some("UTC".to_string().into()))
-        );
-    }
-
-    #[test]
-    fn test_arrow_from_delta_timestampntz_type() {
-        let timestamp_field = DataType::Primitive(PrimitiveType::TimestampNtz);
-        assert_eq!(
-            <ArrowDataType as TryFrom<&DataType>>::try_from(&timestamp_field).unwrap(),
-            ArrowDataType::Timestamp(TimeUnit::Microsecond, None)
-        );
-    }
-
-    #[test]
-    fn test_delta_from_arrow_timestamp_type_no_tz() {
-        let timestamp_field = ArrowDataType::Timestamp(TimeUnit::Microsecond, None);
-        assert_eq!(
-            <DataType as TryFrom<&ArrowDataType>>::try_from(&timestamp_field).unwrap(),
-            DataType::Primitive(PrimitiveType::TimestampNtz)
-        );
-    }
-
-    #[test]
-    fn test_delta_from_arrow_timestamp_type_with_tz() {
-        let timestamp_field =
-            ArrowDataType::Timestamp(TimeUnit::Microsecond, Some("UTC".to_string().into()));
-        assert_eq!(
-            <DataType as TryFrom<&ArrowDataType>>::try_from(&timestamp_field).unwrap(),
-            DataType::Primitive(PrimitiveType::Timestamp)
-        );
-    }
-
-    #[test]
-    fn test_delta_from_arrow_map_type() {
-        let arrow_map = ArrowDataType::Map(
-            Arc::new(ArrowField::new(
-                "entries",
-                ArrowDataType::Struct(
-                    vec![
-                        ArrowField::new("key", ArrowDataType::Int8, false),
-                        ArrowField::new("value", ArrowDataType::Binary, true),
-                    ]
-                    .into(),
-                ),
-                false,
-            )),
-            false,
-        );
-        let converted_map: DataType = (&arrow_map).try_into().unwrap();
-
-        assert_eq!(
-            converted_map,
-            DataType::Map(Box::new(MapType::new(
-                DataType::Primitive(PrimitiveType::Byte),
-                DataType::Primitive(PrimitiveType::Binary),
-                true,
-            )))
-        );
-    }
-
     #[test]
     fn test_record_batch_from_map_type() {
         let keys = vec!["0", "1", "5", "6", "7"];
@@ -881,47 +520,7 @@ mod tests {
         let entry_offsets = vec![0u32, 1, 1, 4, 5, 5];
         let num_rows = keys.len();
 
-        // Copied the function `new_from_string` with the patched code from https://github.com/apache/arrow-rs/pull/4808
-        // This should be reverted back [`MapArray::new_from_strings`] once arrow is upgraded in this project.
- fn new_from_strings<'a>( - keys: impl Iterator, - values: &dyn Array, - entry_offsets: &[u32], - ) -> Result { - let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); - let keys_data = StringArray::from_iter_values(keys); - - let keys_field = Arc::new(Field::new("key", ArrowDataType::Utf8, false)); - let values_field = Arc::new(Field::new( - "value", - values.data_type().clone(), - values.null_count() > 0, - )); - - let entry_struct = StructArray::from(vec![ - (keys_field, Arc::new(keys_data) as ArrayRef), - (values_field, make_array(values.to_data())), - ]); - - let map_data_type = ArrowDataType::Map( - Arc::new(Field::new( - "entries", - entry_struct.data_type().clone(), - false, - )), - false, - ); - - let map_data = ArrayData::builder(map_data_type) - .len(entry_offsets.len() - 1) - .add_buffer(entry_offsets_buffer) - .add_child_data(entry_struct.into_data()) - .build()?; - - Ok(MapArray::from(map_data)) - } - - let map_array = new_from_strings( + let map_array = MapArray::new_from_strings( keys.into_iter(), &arrow::array::BinaryArray::from(values), entry_offsets.as_slice(), @@ -942,9 +541,8 @@ mod tests { ])) .expect("Could not get schema"); - let record_batch = - arrow::record_batch::RecordBatch::try_new(Arc::new(schema), vec![Arc::new(map_array)]) - .expect("Failed to create RecordBatch"); + let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(map_array)]) + .expect("Failed to create RecordBatch"); assert_eq!(record_batch.num_columns(), 1); assert_eq!(record_batch.num_rows(), num_rows); diff --git a/crates/core/src/kernel/expressions/eval.rs b/crates/core/src/kernel/expressions/eval.rs deleted file mode 100644 index cb6beea3ad..0000000000 --- a/crates/core/src/kernel/expressions/eval.rs +++ /dev/null @@ -1,384 +0,0 @@ -//! Default Expression handler. -//! -//! Expression handling based on arrow-rs compute kernels. - -use std::sync::Arc; - -use arrow_arith::boolean::{and, is_null, not, or}; -use arrow_arith::numeric::{add, div, mul, sub}; -use arrow_array::{ - Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, - StructArray, TimestampMicrosecondArray, -}; -use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; -use arrow_schema::{ArrowError, Field as ArrowField, Schema as ArrowSchema}; -use arrow_select::nullif::nullif; - -use crate::kernel::arrow::extract::extract_column; -use crate::kernel::error::{DeltaResult, Error}; -use crate::kernel::expressions::{scalars::Scalar, Expression}; -use crate::kernel::expressions::{BinaryOperator, UnaryOperator}; -use crate::kernel::{DataType, PrimitiveType, VariadicOperator}; - -fn downcast_to_bool(arr: &dyn Array) -> DeltaResult<&BooleanArray> { - arr.as_any() - .downcast_ref::() - .ok_or(Error::Generic("expected boolean array".to_string())) -} - -fn wrap_comparison_result(arr: BooleanArray) -> ArrayRef { - Arc::new(arr) as Arc -} - -// TODO leverage scalars / Datum - -impl Scalar { - /// Convert scalar to arrow array. 
- pub fn to_array(&self, num_rows: usize) -> DeltaResult { - use Scalar::*; - let arr: ArrayRef = match self { - Integer(val) => Arc::new(Int32Array::from_value(*val, num_rows)), - Long(val) => Arc::new(Int64Array::from_value(*val, num_rows)), - Short(val) => Arc::new(Int16Array::from_value(*val, num_rows)), - Byte(val) => Arc::new(Int8Array::from_value(*val, num_rows)), - Float(val) => Arc::new(Float32Array::from_value(*val, num_rows)), - Double(val) => Arc::new(Float64Array::from_value(*val, num_rows)), - String(val) => Arc::new(StringArray::from(vec![val.clone(); num_rows])), - Boolean(val) => Arc::new(BooleanArray::from(vec![*val; num_rows])), - Timestamp(val) => { - Arc::new(TimestampMicrosecondArray::from_value(*val, num_rows).with_timezone("UTC")) - } - TimestampNtz(val) => Arc::new(TimestampMicrosecondArray::from_value(*val, num_rows)), - Date(val) => Arc::new(Date32Array::from_value(*val, num_rows)), - Binary(val) => Arc::new(BinaryArray::from(vec![val.as_slice(); num_rows])), - Decimal(val, precision, scale) => Arc::new( - Decimal128Array::from_value(*val, num_rows) - .with_precision_and_scale(*precision, *scale)?, - ), - Null(data_type) => match data_type { - DataType::Primitive(primitive) => match primitive { - PrimitiveType::Byte => Arc::new(Int8Array::new_null(num_rows)), - PrimitiveType::Short => Arc::new(Int16Array::new_null(num_rows)), - PrimitiveType::Integer => Arc::new(Int32Array::new_null(num_rows)), - PrimitiveType::Long => Arc::new(Int64Array::new_null(num_rows)), - PrimitiveType::Float => Arc::new(Float32Array::new_null(num_rows)), - PrimitiveType::Double => Arc::new(Float64Array::new_null(num_rows)), - PrimitiveType::String => Arc::new(StringArray::new_null(num_rows)), - PrimitiveType::Boolean => Arc::new(BooleanArray::new_null(num_rows)), - PrimitiveType::Timestamp => { - Arc::new(TimestampMicrosecondArray::new_null(num_rows).with_timezone("UTC")) - } - PrimitiveType::TimestampNtz => { - Arc::new(TimestampMicrosecondArray::new_null(num_rows)) - } - PrimitiveType::Date => Arc::new(Date32Array::new_null(num_rows)), - PrimitiveType::Binary => Arc::new(BinaryArray::new_null(num_rows)), - PrimitiveType::Decimal(precision, scale) => Arc::new( - Decimal128Array::new_null(num_rows) - .with_precision_and_scale(*precision, *scale) - .unwrap(), - ), - }, - DataType::Array(_) => unimplemented!(), - DataType::Map { .. } => unimplemented!(), - DataType::Struct { .. } => unimplemented!(), - }, - Struct(values, fields) => { - let mut columns = Vec::with_capacity(values.len()); - for val in values { - columns.push(val.to_array(num_rows)?); - } - Arc::new(StructArray::try_new( - fields - .iter() - .map(TryInto::::try_into) - .collect::, _>>()? - .into(), - columns, - None, - )?) - } - }; - Ok(arr) - } -} - -/// evaluate expression -pub(crate) fn evaluate_expression( - expression: &Expression, - batch: &RecordBatch, - result_type: Option<&DataType>, -) -> DeltaResult { - use BinaryOperator::*; - use Expression::*; - - match (expression, result_type) { - (Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?), - (Column(name), _) => { - if name.contains('.') { - let mut path = name.split('.'); - // Safety: we know that the first path step exists, because we checked for '.' 
- let arr = extract_column(batch, path.next().unwrap(), &mut path).cloned()?; - // NOTE: need to assign first so that rust can figure out lifetimes - Ok(arr) - } else { - batch - .column_by_name(name) - .ok_or(Error::MissingColumn(name.clone())) - .cloned() - } - } - (Struct(fields), Some(DataType::Struct(schema))) => { - let output_schema: ArrowSchema = schema.as_ref().try_into()?; - let mut columns = Vec::with_capacity(fields.len()); - for (expr, field) in fields.iter().zip(schema.fields()) { - columns.push(evaluate_expression(expr, batch, Some(field.data_type()))?); - } - Ok(Arc::new(StructArray::try_new( - output_schema.fields().clone(), - columns, - None, - )?)) - } - (Struct(_), _) => Err(Error::Generic( - "Data type is required to evaluate struct expressions".to_string(), - )), - (UnaryOperation { op, expr }, _) => { - let arr = evaluate_expression(expr.as_ref(), batch, None)?; - Ok(match op { - UnaryOperator::Not => Arc::new(not(downcast_to_bool(&arr)?)?), - UnaryOperator::IsNull => Arc::new(is_null(&arr)?), - }) - } - (BinaryOperation { op, left, right }, _) => { - let left_arr = evaluate_expression(left.as_ref(), batch, None)?; - let right_arr = evaluate_expression(right.as_ref(), batch, None)?; - - type Operation = fn(&dyn Datum, &dyn Datum) -> Result, ArrowError>; - let eval: Operation = match op { - Plus => add, - Minus => sub, - Multiply => mul, - Divide => div, - LessThan => |l, r| lt(l, r).map(wrap_comparison_result), - LessThanOrEqual => |l, r| lt_eq(l, r).map(wrap_comparison_result), - GreaterThan => |l, r| gt(l, r).map(wrap_comparison_result), - GreaterThanOrEqual => |l, r| gt_eq(l, r).map(wrap_comparison_result), - Equal => |l, r| eq(l, r).map(wrap_comparison_result), - NotEqual => |l, r| neq(l, r).map(wrap_comparison_result), - }; - - eval(&left_arr, &right_arr).map_err(|err| Error::GenericError { - source: Box::new(err), - }) - } - (VariadicOperation { op, exprs }, _) => { - let reducer = match op { - VariadicOperator::And => and, - VariadicOperator::Or => or, - }; - exprs - .iter() - .map(|expr| evaluate_expression(expr, batch, Some(&DataType::BOOLEAN))) - .reduce(|l, r| { - Ok(reducer(downcast_to_bool(&l?)?, downcast_to_bool(&r?)?) - .map(wrap_comparison_result)?) - }) - .transpose()? - .ok_or(Error::Generic("empty expression".to_string())) - } - (NullIf { expr, if_expr }, _) => { - let expr_arr = evaluate_expression(expr.as_ref(), batch, None)?; - let if_expr_arr = - evaluate_expression(if_expr.as_ref(), batch, Some(&DataType::BOOLEAN))?; - let if_expr_arr = downcast_to_bool(&if_expr_arr)?; - Ok(nullif(&expr_arr, if_expr_arr)?) 
- } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::Int32Array; - use arrow_schema::{DataType, Field, Fields, Schema}; - use std::ops::{Add, Div, Mul, Sub}; - - #[test] - fn test_extract_column() { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let values = Int32Array::from(vec![1, 2, 3]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap(); - let column = Expression::Column("a".to_string()); - - let results = evaluate_expression(&column, &batch, None).unwrap(); - assert_eq!(results.as_ref(), &values); - - let schema = Schema::new(vec![Field::new( - "b", - DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, false)])), - false, - )]); - - let struct_values: ArrayRef = Arc::new(values.clone()); - let struct_array = StructArray::from(vec![( - Arc::new(Field::new("a", DataType::Int32, false)), - struct_values, - )]); - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(struct_array.clone())], - ) - .unwrap(); - let column = Expression::Column("b.a".to_string()); - let results = evaluate_expression(&column, &batch, None).unwrap(); - assert_eq!(results.as_ref(), &values); - } - - #[test] - fn test_binary_op_scalar() { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let values = Int32Array::from(vec![1, 2, 3]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::Column("a".to_string()); - - let expression = Box::new(column.clone().add(Expression::Literal(Scalar::Integer(1)))); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![2, 3, 4])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().sub(Expression::Literal(Scalar::Integer(1)))); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![0, 1, 2])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().mul(Expression::Literal(Scalar::Integer(2)))); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![2, 4, 6])); - assert_eq!(results.as_ref(), expected.as_ref()); - - // TODO handle type casting - let expression = Box::new(column.div(Expression::Literal(Scalar::Integer(1)))); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![1, 2, 3])); - assert_eq!(results.as_ref(), expected.as_ref()) - } - - #[test] - fn test_binary_op() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - let values = Int32Array::from(vec![1, 2, 3]); - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(values.clone()), Arc::new(values)], - ) - .unwrap(); - let column_a = Expression::Column("a".to_string()); - let column_b = Expression::Column("b".to_string()); - - let expression = Box::new(column_a.clone().add(column_b.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![2, 4, 6])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column_a.clone().sub(column_b.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = 
Arc::new(Int32Array::from(vec![0, 0, 0])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column_a.clone().mul(column_b)); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(Int32Array::from(vec![1, 4, 9])); - assert_eq!(results.as_ref(), expected.as_ref()); - } - - #[test] - fn test_binary_cmp() { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let values = Int32Array::from(vec![1, 2, 3]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::Column("a".to_string()); - let lit = Expression::Literal(Scalar::Integer(2)); - - let expression = Box::new(column.clone().lt(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().lt_eq(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().gt(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, false, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().gt_eq(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, true, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().eq(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column.clone().ne(lit.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - } - - #[test] - fn test_logical() { - let schema = Schema::new(vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Boolean, false), - ]); - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(BooleanArray::from(vec![true, false])), - Arc::new(BooleanArray::from(vec![false, true])), - ], - ) - .unwrap(); - let column_a = Expression::Column("a".to_string()); - let column_b = Expression::Column("b".to_string()); - - let expression = Box::new(column_a.clone().and(column_b.clone())); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![false, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new( - column_a - .clone() - .and(Expression::literal(Scalar::Boolean(true))), - ); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - - let expression = Box::new(column_a.clone().or(column_b)); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, true])); - assert_eq!(results.as_ref(), expected.as_ref()); - 
- let expression = Box::new( - column_a - .clone() - .or(Expression::literal(Scalar::Boolean(false))), - ); - let results = evaluate_expression(&expression, &batch, None).unwrap(); - let expected = Arc::new(BooleanArray::from(vec![true, false])); - assert_eq!(results.as_ref(), expected.as_ref()); - } -} diff --git a/crates/core/src/kernel/expressions/mod.rs b/crates/core/src/kernel/expressions/mod.rs deleted file mode 100644 index dd8aab51de..0000000000 --- a/crates/core/src/kernel/expressions/mod.rs +++ /dev/null @@ -1,478 +0,0 @@ -//! expressions. - -use std::collections::HashSet; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_schema::Schema as ArrowSchema; -use itertools::Itertools; - -use self::eval::evaluate_expression; -use super::{DataType, DeltaResult, SchemaRef}; - -pub use self::scalars::*; - -mod eval; -mod scalars; - -/// Interface for implementing an Expression evaluator. -/// -/// It contains one Expression which can be evaluated on multiple ColumnarBatches. -/// Connectors can implement this interface to optimize the evaluation using the -/// connector specific capabilities. -pub trait ExpressionEvaluator { - /// Evaluate the expression on given ColumnarBatch data. - /// - /// Contains one value for each row of the input. - /// The data type of the output is same as the type output of the expression this evaluator is using. - fn evaluate(&self, batch: &RecordBatch) -> DeltaResult; -} - -/// Provides expression evaluation capability to Delta Kernel. -/// -/// Delta Kernel can use this client to evaluate predicate on partition filters, -/// fill up partition column values and any computation on data using Expressions. -pub trait ExpressionHandler { - /// Create an [`ExpressionEvaluator`] that can evaluate the given [`Expression`] - /// on columnar batches with the given [`Schema`] to produce data of [`DataType`]. - /// - /// # Parameters - /// - /// - `schema`: Schema of the input data. - /// - `expression`: Expression to evaluate. - /// - `output_type`: Expected result data type. 
- /// - /// [`Schema`]: crate::schema::StructType - /// [`DataType`]: crate::schema::DataType - fn get_evaluator( - &self, - schema: SchemaRef, - expression: Expression, - output_type: DataType, - ) -> Arc; -} - -/// Default implementation of [`ExpressionHandler`] that uses [`evaluate_expression`] -#[derive(Debug)] -pub struct ArrowExpressionHandler {} - -impl ExpressionHandler for ArrowExpressionHandler { - fn get_evaluator( - &self, - schema: SchemaRef, - expression: Expression, - output_type: DataType, - ) -> Arc { - Arc::new(DefaultExpressionEvaluator { - input_schema: schema, - expression: Box::new(expression), - output_type, - }) - } -} - -/// Default implementation of [`ExpressionEvaluator`] that uses [`evaluate_expression`] -#[derive(Debug)] -pub struct DefaultExpressionEvaluator { - input_schema: SchemaRef, - expression: Box, - output_type: DataType, -} - -impl ExpressionEvaluator for DefaultExpressionEvaluator { - fn evaluate(&self, batch: &RecordBatch) -> DeltaResult { - let _input_schema: ArrowSchema = self.input_schema.as_ref().try_into()?; - // TODO: make sure we have matching schemas for validation - // if batch.schema().as_ref() != &input_schema { - // return Err(Error::Generic(format!( - // "input schema does not match batch schema: {:?} != {:?}", - // input_schema, - // batch.schema() - // ))); - // }; - evaluate_expression(&self.expression, batch, Some(&self.output_type)) - } -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -/// A binary operator. -pub enum BinaryOperator { - /// Arithmetic Plus - Plus, - /// Arithmetic Minus - Minus, - /// Arithmetic Multiply - Multiply, - /// Arithmetic Divide - Divide, - /// Comparison Less Than - LessThan, - /// Comparison Less Than Or Equal - LessThanOrEqual, - /// Comparison Greater Than - GreaterThan, - /// Comparison Greater Than Or Equal - GreaterThanOrEqual, - /// Comparison Equal - Equal, - /// Comparison Not Equal - NotEqual, -} - -/// Variadic operators -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum VariadicOperator { - /// AND - And, - /// OR - Or, -} - -impl Display for BinaryOperator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - // Self::And => write!(f, "AND"), - // Self::Or => write!(f, "OR"), - Self::Plus => write!(f, "+"), - Self::Minus => write!(f, "-"), - Self::Multiply => write!(f, "*"), - Self::Divide => write!(f, "/"), - Self::LessThan => write!(f, "<"), - Self::LessThanOrEqual => write!(f, "<="), - Self::GreaterThan => write!(f, ">"), - Self::GreaterThanOrEqual => write!(f, ">="), - Self::Equal => write!(f, "="), - Self::NotEqual => write!(f, "!="), - } - } -} - -#[derive(Debug, Copy, Clone, PartialEq)] -/// A unary operator. -pub enum UnaryOperator { - /// Unary Not - Not, - /// Unary Is Null - IsNull, -} - -/// A SQL expression. -/// -/// These expressions do not track or validate data types, other than the type -/// of literals. It is up to the expression evaluator to validate the -/// expression against a schema and add appropriate casts as required. -#[derive(Debug, Clone, PartialEq)] -pub enum Expression { - /// A literal value. - Literal(Scalar), - /// A column reference by name. - Column(String), - /// - Struct(Vec), - /// A binary operation. - BinaryOperation { - /// The operator. - op: BinaryOperator, - /// The left-hand side of the operation. - left: Box, - /// The right-hand side of the operation. - right: Box, - }, - /// A unary operation. - UnaryOperation { - /// The operator. - op: UnaryOperator, - /// The expression. 
- expr: Box, - }, - /// A variadic operation. - VariadicOperation { - /// The operator. - op: VariadicOperator, - /// The expressions. - exprs: Vec, - }, - /// A NULLIF expression. - NullIf { - /// The expression to evaluate. - expr: Box, - /// The expression to compare against. - if_expr: Box, - }, - // TODO: support more expressions, such as IS IN, LIKE, etc. -} - -impl> From for Expression { - fn from(value: T) -> Self { - Self::literal(value) - } -} - -impl Display for Expression { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Literal(l) => write!(f, "{}", l), - Self::Column(name) => write!(f, "Column({})", name), - Self::Struct(exprs) => write!( - f, - "Struct({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ), - Self::BinaryOperation { op, left, right } => write!(f, "{} {} {}", left, op, right), - Self::UnaryOperation { op, expr } => match op { - UnaryOperator::Not => write!(f, "NOT {}", expr), - UnaryOperator::IsNull => write!(f, "{} IS NULL", expr), - }, - Self::VariadicOperation { op, exprs } => match op { - VariadicOperator::And => { - write!( - f, - "AND({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ) - } - VariadicOperator::Or => { - write!( - f, - "OR({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ) - } - }, - Self::NullIf { expr, if_expr } => write!(f, "NULLIF({}, {})", expr, if_expr), - } - } -} - -impl Expression { - /// Returns a set of columns referenced by this expression. - pub fn references(&self) -> HashSet<&str> { - let mut set = HashSet::new(); - - for expr in self.walk() { - if let Self::Column(name) = expr { - set.insert(name.as_str()); - } - } - - set - } - - /// Create an new expression for a column reference - pub fn column(name: impl Into) -> Self { - Self::Column(name.into()) - } - - /// Create a new expression for a literal value - pub fn literal(value: impl Into) -> Self { - Self::Literal(value.into()) - } - - /// Create a new expression for a struct - pub fn struct_expr(exprs: impl IntoIterator) -> Self { - Self::Struct(exprs.into_iter().collect()) - } - - /// Create a new expression for a unary operation - pub fn unary(op: UnaryOperator, expr: impl Into) -> Self { - Self::UnaryOperation { - op, - expr: Box::new(expr.into()), - } - } - - /// Create a new expression for a binary operation - pub fn binary( - op: BinaryOperator, - lhs: impl Into, - rhs: impl Into, - ) -> Self { - Self::BinaryOperation { - op, - left: Box::new(lhs.into()), - right: Box::new(rhs.into()), - } - } - - /// Create a new expression for a variadic operation - pub fn variadic(op: VariadicOperator, other: impl IntoIterator) -> Self { - let mut exprs = other.into_iter().collect::>(); - if exprs.is_empty() { - // TODO this might break if we introduce new variadic operators? 
- return Self::literal(matches!(op, VariadicOperator::And)); - } - if exprs.len() == 1 { - return exprs.pop().unwrap(); - } - Self::VariadicOperation { op, exprs } - } - - /// Create a new expression `self == other` - pub fn eq(self, other: Self) -> Self { - Self::binary(BinaryOperator::Equal, self, other) - } - - /// Create a new expression `self != other` - pub fn ne(self, other: Self) -> Self { - Self::binary(BinaryOperator::NotEqual, self, other) - } - - /// Create a new expression `self < other` - pub fn lt(self, other: Self) -> Self { - Self::binary(BinaryOperator::LessThan, self, other) - } - - /// Create a new expression `self > other` - pub fn gt(self, other: Self) -> Self { - Self::binary(BinaryOperator::GreaterThan, self, other) - } - - /// Create a new expression `self >= other` - pub fn gt_eq(self, other: Self) -> Self { - Self::binary(BinaryOperator::GreaterThanOrEqual, self, other) - } - - /// Create a new expression `self <= other` - pub fn lt_eq(self, other: Self) -> Self { - Self::binary(BinaryOperator::LessThanOrEqual, self, other) - } - - /// Create a new expression `self AND other` - pub fn and(self, other: Self) -> Self { - self.and_many([other]) - } - - /// Create a new expression `self AND others` - pub fn and_many(self, other: impl IntoIterator) -> Self { - Self::variadic(VariadicOperator::And, std::iter::once(self).chain(other)) - } - - /// Create a new expression `self AND other` - pub fn or(self, other: Self) -> Self { - self.or_many([other]) - } - - /// Create a new expression `self OR other` - pub fn or_many(self, other: impl IntoIterator) -> Self { - Self::variadic(VariadicOperator::Or, std::iter::once(self).chain(other)) - } - - /// Create a new expression `self IS NULL` - pub fn is_null(self) -> Self { - Self::unary(UnaryOperator::IsNull, self) - } - - /// Create a new expression `NULLIF(self, other)` - pub fn null_if(self, other: Self) -> Self { - Self::NullIf { - expr: Box::new(self), - if_expr: Box::new(other), - } - } - - fn walk(&self) -> impl Iterator + '_ { - let mut stack = vec![self]; - std::iter::from_fn(move || { - let expr = stack.pop()?; - match expr { - Self::Literal(_) => {} - Self::Column { .. } => {} - Self::Struct(exprs) => { - stack.extend(exprs.iter()); - } - Self::BinaryOperation { left, right, .. } => { - stack.push(left); - stack.push(right); - } - Self::UnaryOperation { expr, .. 
} => { - stack.push(expr); - } - Self::VariadicOperation { op, exprs } => match op { - VariadicOperator::And | VariadicOperator::Or => { - stack.extend(exprs.iter()); - } - }, - Self::NullIf { expr, if_expr } => { - stack.push(expr); - stack.push(if_expr); - } - } - Some(expr) - }) - } -} - -impl std::ops::Add for Expression { - type Output = Self; - - fn add(self, rhs: Expression) -> Self::Output { - Self::binary(BinaryOperator::Plus, self, rhs) - } -} - -impl std::ops::Sub for Expression { - type Output = Self; - - fn sub(self, rhs: Expression) -> Self::Output { - Self::binary(BinaryOperator::Minus, self, rhs) - } -} - -impl std::ops::Mul for Expression { - type Output = Self; - - fn mul(self, rhs: Expression) -> Self::Output { - Self::binary(BinaryOperator::Multiply, self, rhs) - } -} - -impl std::ops::Div for Expression { - type Output = Self; - - fn div(self, rhs: Expression) -> Self::Output { - Self::binary(BinaryOperator::Divide, self, rhs) - } -} - -#[cfg(test)] -mod tests { - use super::Expression as Expr; - - #[test] - fn test_expression_format() { - let col_ref = Expr::column("x"); - let cases = [ - (col_ref.clone(), "Column(x)"), - (col_ref.clone().eq(Expr::literal(2)), "Column(x) = 2"), - ( - col_ref - .clone() - .gt_eq(Expr::literal(2)) - .and(col_ref.clone().lt_eq(Expr::literal(10))), - "AND(Column(x) >= 2, Column(x) <= 10)", - ), - ( - col_ref - .clone() - .gt(Expr::literal(2)) - .or(col_ref.clone().lt(Expr::literal(10))), - "OR(Column(x) > 2, Column(x) < 10)", - ), - ( - (col_ref.clone() - Expr::literal(4)).lt(Expr::literal(10)), - "Column(x) - 4 < 10", - ), - ( - (col_ref.clone() + Expr::literal(4)) / Expr::literal(10) * Expr::literal(42), - "Column(x) + 4 / 10 * 42", - ), - (col_ref.eq(Expr::literal("foo")), "Column(x) = 'foo'"), - ]; - - for (expr, expected) in cases { - let result = format!("{}", expr); - assert_eq!(result, expected); - } - } -} diff --git a/crates/core/src/kernel/expressions/scalars.rs b/crates/core/src/kernel/expressions/scalars.rs deleted file mode 100644 index 571c2abf92..0000000000 --- a/crates/core/src/kernel/expressions/scalars.rs +++ /dev/null @@ -1,559 +0,0 @@ -//! Scalar values for use in expressions. - -use std::cmp::Ordering; -use std::fmt::{Display, Formatter}; - -use arrow_array::Array; -use arrow_schema::TimeUnit; -use chrono::{DateTime, NaiveDate, NaiveDateTime, TimeZone, Utc}; -use object_store::path::Path; - -use crate::kernel::{DataType, Error, PrimitiveType, StructField}; -use crate::NULL_PARTITION_VALUE_DATA_PATH; - -/// A single value, which can be null. Used for representing literal values -/// in [Expressions][crate::expressions::Expression]. -#[derive(Debug, Clone, PartialEq)] -pub enum Scalar { - /// 32bit integer - Integer(i32), - /// 64bit integer - Long(i64), - /// 16bit integer - Short(i16), - /// 8bit integer - Byte(i8), - /// 32bit floating point - Float(f32), - /// 64bit floating point - Double(f64), - /// utf-8 encoded string. - String(String), - /// true or false value - Boolean(bool), - /// Microsecond precision timestamp, adjusted to UTC. - Timestamp(i64), - /// Microsecond precision timestamp, with no timezone. - TimestampNtz(i64), - /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 - Date(i32), - /// Binary data - Binary(Vec), - /// Decimal value - Decimal(i128, u8, i8), - /// Null value with a given data type. - Null(DataType), - /// Struct value - Struct(Vec, Vec), -} - -impl Scalar { - /// Returns the data type of this scalar. 
- pub fn data_type(&self) -> DataType { - match self { - Self::Integer(_) => DataType::Primitive(PrimitiveType::Integer), - Self::Long(_) => DataType::Primitive(PrimitiveType::Long), - Self::Short(_) => DataType::Primitive(PrimitiveType::Short), - Self::Byte(_) => DataType::Primitive(PrimitiveType::Byte), - Self::Float(_) => DataType::Primitive(PrimitiveType::Float), - Self::Double(_) => DataType::Primitive(PrimitiveType::Double), - Self::String(_) => DataType::Primitive(PrimitiveType::String), - Self::Boolean(_) => DataType::Primitive(PrimitiveType::Boolean), - Self::Timestamp(_) => DataType::Primitive(PrimitiveType::Timestamp), - Self::TimestampNtz(_) => DataType::Primitive(PrimitiveType::TimestampNtz), - Self::Date(_) => DataType::Primitive(PrimitiveType::Date), - Self::Binary(_) => DataType::Primitive(PrimitiveType::Binary), - // Unwrapping should be safe, since the scalar should never have values that are unsupported - Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale).unwrap(), - Self::Null(data_type) => data_type.clone(), - Self::Struct(_, fields) => DataType::struct_type(fields.clone()), - } - } - - /// Returns true if this scalar is null. - pub fn is_null(&self) -> bool { - matches!(self, Self::Null(_)) - } - - /// Serializes this scalar as a string. - pub fn serialize(&self) -> String { - match self { - Self::String(s) => s.to_owned(), - Self::Byte(b) => b.to_string(), - Self::Short(s) => s.to_string(), - Self::Integer(i) => i.to_string(), - Self::Long(l) => l.to_string(), - Self::Float(f) => f.to_string(), - Self::Double(d) => d.to_string(), - Self::Boolean(b) => { - if *b { - "true".to_string() - } else { - "false".to_string() - } - } - Self::TimestampNtz(ts) | Self::Timestamp(ts) => { - let ts = Utc.timestamp_micros(*ts).single().unwrap(); - ts.format("%Y-%m-%d %H:%M:%S%.6f").to_string() - } - Self::Date(days) => { - let date = DateTime::from_timestamp(*days as i64 * 24 * 3600, 0).unwrap(); - date.format("%Y-%m-%d").to_string() - } - Self::Decimal(value, _, scale) => match scale.cmp(&0) { - Ordering::Equal => value.to_string(), - Ordering::Greater => { - let scalar_multiple = 10_i128.pow(*scale as u32); - let mut s = String::new(); - s.push_str((value / scalar_multiple).to_string().as_str()); - s.push('.'); - s.push_str(&format!( - "{:0>scale$}", - value % scalar_multiple, - scale = *scale as usize - )); - s - } - Ordering::Less => { - let mut s = value.to_string(); - for _ in 0..(scale.abs()) { - s.push('0'); - } - s - } - }, - Self::Binary(val) => create_escaped_binary_string(val.as_slice()), - Self::Null(_) => "null".to_string(), - Self::Struct(_, _) => todo!("serializing struct values is not yet supported"), - } - } - - /// Serializes this scalar as a string for use in hive partition file names. - pub fn serialize_encoded(&self) -> String { - if self.is_null() { - return NULL_PARTITION_VALUE_DATA_PATH.to_string(); - } - Path::from(self.serialize()).to_string() - } - - /// Create a [`Scalar`] form a row in an arrow array. 
- pub fn from_array(arr: &dyn Array, index: usize) -> Option { - use arrow_array::*; - use arrow_schema::DataType::*; - - if arr.len() <= index { - return None; - } - if arr.is_null(index) { - return Some(Self::Null(arr.data_type().try_into().ok()?)); - } - - match arr.data_type() { - Utf8 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::String(v.value(index).to_string())), - LargeUtf8 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::String(v.value(index).to_string())), - Boolean => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Boolean(v.value(index))), - Binary => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Binary(v.value(index).to_vec())), - LargeBinary => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Binary(v.value(index).to_vec())), - FixedSizeBinary(_) => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Binary(v.value(index).to_vec())), - Int8 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Byte(v.value(index))), - Int16 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Short(v.value(index))), - Int32 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Integer(v.value(index))), - Int64 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Long(v.value(index))), - UInt8 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Byte(v.value(index) as i8)), - UInt16 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Short(v.value(index) as i16)), - UInt32 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Integer(v.value(index) as i32)), - UInt64 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Long(v.value(index) as i64)), - Float32 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Float(v.value(index))), - Float64 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Double(v.value(index))), - Decimal128(precision, scale) => { - arr.as_any().downcast_ref::().map(|v| { - let value = v.value(index); - Self::Decimal(value, *precision, *scale) - }) - } - Date32 => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Date(v.value(index))), - // TODO handle timezones when implementing timestamp ntz feature. 
- Timestamp(TimeUnit::Microsecond, tz) => match tz { - None => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Timestamp(v.value(index))), - Some(tz_str) if tz_str.as_ref() == "UTC" => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Timestamp(v.clone().with_timezone("UTC").value(index))), - _ => None, - }, - Struct(fields) => { - let struct_fields = fields - .iter() - .flat_map(|f| TryFrom::try_from(f.as_ref())) - .collect::>(); - let values = arr - .as_any() - .downcast_ref::() - .and_then(|struct_arr| { - struct_fields - .iter() - .map(|f: &StructField| { - struct_arr - .column_by_name(f.name()) - .and_then(|c| Self::from_array(c.as_ref(), index)) - }) - .collect::>>() - })?; - if struct_fields.len() != values.len() { - return None; - } - Some(Self::Struct(values, struct_fields)) - } - Float16 - | Decimal256(_, _) - | List(_) - | LargeList(_) - | FixedSizeList(_, _) - | Map(_, _) - | Date64 - | Timestamp(_, _) - | Time32(_) - | Time64(_) - | Duration(_) - | Interval(_) - | Dictionary(_, _) - | RunEndEncoded(_, _) - | Union(_, _) - | Utf8View - | BinaryView - | ListView(_) - | LargeListView(_) - | Null => None, - } - } -} - -impl PartialOrd for Scalar { - fn partial_cmp(&self, other: &Self) -> Option { - use Scalar::*; - match (self, other) { - (Null(_), Null(_)) => Some(Ordering::Equal), - (Integer(a), Integer(b)) => a.partial_cmp(b), - (Long(a), Long(b)) => a.partial_cmp(b), - (Short(a), Short(b)) => a.partial_cmp(b), - (Byte(a), Byte(b)) => a.partial_cmp(b), - (Float(a), Float(b)) => a.partial_cmp(b), - (Double(a), Double(b)) => a.partial_cmp(b), - (String(a), String(b)) => a.partial_cmp(b), - (Boolean(a), Boolean(b)) => a.partial_cmp(b), - (Timestamp(a), Timestamp(b)) => a.partial_cmp(b), - (TimestampNtz(a), TimestampNtz(b)) => a.partial_cmp(b), - (Date(a), Date(b)) => a.partial_cmp(b), - (Binary(a), Binary(b)) => a.partial_cmp(b), - (Decimal(a, _, _), Decimal(b, _, _)) => a.partial_cmp(b), - (Struct(a, _), Struct(b, _)) => a.partial_cmp(b), - // TODO should we make an assumption about the ordering of nulls? - // rigth now this is only used for internal purposes. 
- (Null(_), _) => Some(Ordering::Less), - (_, Null(_)) => Some(Ordering::Greater), - _ => None, - } - } -} - -impl Display for Scalar { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Integer(i) => write!(f, "{}", i), - Self::Long(i) => write!(f, "{}", i), - Self::Short(i) => write!(f, "{}", i), - Self::Byte(i) => write!(f, "{}", i), - Self::Float(fl) => write!(f, "{}", fl), - Self::Double(fl) => write!(f, "{}", fl), - Self::String(s) => write!(f, "'{}'", s), - Self::Boolean(b) => write!(f, "{}", b), - Self::Timestamp(ts) => write!(f, "{}", ts), - Self::TimestampNtz(ts) => write!(f, "{}", ts), - Self::Date(d) => write!(f, "{}", d), - Self::Binary(b) => write!(f, "{:?}", b), - Self::Decimal(value, _, scale) => match scale.cmp(&0) { - Ordering::Equal => { - write!(f, "{}", value) - } - Ordering::Greater => { - let scalar_multiple = 10_i128.pow(*scale as u32); - write!(f, "{}", value / scalar_multiple)?; - write!(f, ".")?; - write!( - f, - "{:0>scale$}", - value % scalar_multiple, - scale = *scale as usize - ) - } - Ordering::Less => { - write!(f, "{}", value)?; - for _ in 0..(scale.abs()) { - write!(f, "0")?; - } - Ok(()) - } - }, - Self::Null(_) => write!(f, "null"), - Self::Struct(values, fields) => { - write!(f, "{{")?; - for (i, (value, field)) in values.iter().zip(fields.iter()).enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}: {}", field.name, value)?; - } - write!(f, "}}") - } - } - } -} - -impl From for Scalar { - fn from(i: i32) -> Self { - Self::Integer(i) - } -} - -impl From for Scalar { - fn from(i: i64) -> Self { - Self::Long(i) - } -} - -impl From for Scalar { - fn from(b: bool) -> Self { - Self::Boolean(b) - } -} - -impl From<&str> for Scalar { - fn from(s: &str) -> Self { - Self::String(s.into()) - } -} - -impl From for Scalar { - fn from(value: String) -> Self { - Self::String(value) - } -} - -// TODO: add more From impls - -impl PrimitiveType { - fn data_type(&self) -> DataType { - DataType::Primitive(*self) - } - - /// Parses a string into a scalar value. - pub fn parse_scalar(&self, raw: &str) -> Result { - use PrimitiveType::*; - - lazy_static::lazy_static! { - static ref UNIX_EPOCH: DateTime = DateTime::from_timestamp(0, 0).unwrap(); - } - - if raw.is_empty() || raw == NULL_PARTITION_VALUE_DATA_PATH { - return Ok(Scalar::Null(self.data_type())); - } - - match self { - String => Ok(Scalar::String(raw.to_string())), - Byte => self.str_parse_scalar(raw, Scalar::Byte), - Short => self.str_parse_scalar(raw, Scalar::Short), - Integer => self.str_parse_scalar(raw, Scalar::Integer), - Long => self.str_parse_scalar(raw, Scalar::Long), - Float => self.str_parse_scalar(raw, Scalar::Float), - Double => self.str_parse_scalar(raw, Scalar::Double), - Boolean => { - if raw.eq_ignore_ascii_case("true") { - Ok(Scalar::Boolean(true)) - } else if raw.eq_ignore_ascii_case("false") { - Ok(Scalar::Boolean(false)) - } else { - Err(self.parse_error(raw)) - } - } - Date => { - let date = NaiveDate::parse_from_str(raw, "%Y-%m-%d") - .map_err(|_| self.parse_error(raw))? 
- .and_hms_opt(0, 0, 0) - .ok_or(self.parse_error(raw))?; - let date = Utc.from_utc_datetime(&date); - let days = date.signed_duration_since(*UNIX_EPOCH).num_days() as i32; - Ok(Scalar::Date(days)) - } - Timestamp => { - let timestamp = NaiveDateTime::parse_from_str(raw, "%Y-%m-%d %H:%M:%S%.f") - .map_err(|_| self.parse_error(raw))?; - let timestamp = Utc.from_utc_datetime(×tamp); - let micros = timestamp - .signed_duration_since(*UNIX_EPOCH) - .num_microseconds() - .ok_or(self.parse_error(raw))?; - Ok(Scalar::Timestamp(micros)) - } - TimestampNtz => { - let timestamp = NaiveDateTime::parse_from_str(raw, "%Y-%m-%d %H:%M:%S%.f") - .map_err(|_| self.parse_error(raw))?; - let timestamp = Utc.from_utc_datetime(×tamp); - let micros = timestamp - .signed_duration_since(*UNIX_EPOCH) - .num_microseconds() - .ok_or(self.parse_error(raw))?; - Ok(Scalar::TimestampNtz(micros)) - } - Binary => { - let bytes = parse_escaped_binary_string(raw).map_err(|_| self.parse_error(raw))?; - Ok(Scalar::Binary(bytes)) - } - _ => todo!("parsing {:?} is not yet supported", self), - } - } - - fn parse_error(&self, raw: &str) -> Error { - Error::Parse(raw.to_string(), self.data_type()) - } - - fn str_parse_scalar( - &self, - raw: &str, - f: impl FnOnce(T) -> Scalar, - ) -> Result { - match raw.parse() { - Ok(val) => Ok(f(val)), - Err(..) => Err(self.parse_error(raw)), - } - } -} - -fn create_escaped_binary_string(data: &[u8]) -> String { - let mut escaped_string = String::new(); - for &byte in data { - // Convert each byte to its two-digit hexadecimal representation - let hex_representation = format!("{:04X}", byte); - // Append the hexadecimal representation with an escape sequence - escaped_string.push_str("\\u"); - escaped_string.push_str(&hex_representation); - } - escaped_string -} - -fn parse_escaped_binary_string(escaped_string: &str) -> Result, &'static str> { - let mut parsed_bytes = Vec::new(); - let mut chars = escaped_string.chars(); - - while let Some(ch) = chars.next() { - if ch == '\\' { - // Check for the escape sequence "\\u" indicating a hexadecimal value - if chars.next() == Some('u') { - // Read two hexadecimal digits and convert to u8 - if let (Some(digit1), Some(digit2), Some(digit3), Some(digit4)) = - (chars.next(), chars.next(), chars.next(), chars.next()) - { - if let Ok(byte) = - u8::from_str_radix(&format!("{}{}{}{}", digit1, digit2, digit3, digit4), 16) - { - parsed_bytes.push(byte); - } else { - return Err("Error parsing hexadecimal value"); - } - } else { - return Err("Incomplete escape sequence"); - } - } else { - // Unrecognized escape sequence - return Err("Unrecognized escape sequence"); - } - } else { - // Regular character, convert to u8 and push into the result vector - parsed_bytes.push(ch as u8); - } - } - - Ok(parsed_bytes) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_binary_roundtrip() { - let scalar = Scalar::Binary(vec![0, 1, 2, 3, 4, 5]); - let parsed = PrimitiveType::Binary - .parse_scalar(&scalar.serialize()) - .unwrap(); - assert_eq!(scalar, parsed); - } - - #[test] - fn test_decimal_display() { - let s = Scalar::Decimal(123456789, 9, 2); - assert_eq!(s.to_string(), "1234567.89"); - - let s = Scalar::Decimal(123456789, 9, 0); - assert_eq!(s.to_string(), "123456789"); - - let s = Scalar::Decimal(123456789, 9, 9); - assert_eq!(s.to_string(), "0.123456789"); - - let s = Scalar::Decimal(123, 9, -3); - assert_eq!(s.to_string(), "123000"); - } -} diff --git a/crates/core/src/kernel/mod.rs b/crates/core/src/kernel/mod.rs index 876a09a33c..ce788d6c4d 
100644 --- a/crates/core/src/kernel/mod.rs +++ b/crates/core/src/kernel/mod.rs @@ -4,12 +4,11 @@ pub mod arrow; pub mod error; -pub mod expressions; pub mod models; +pub mod scalars; mod snapshot; pub use error::*; -pub use expressions::*; pub use models::*; pub use snapshot::*; diff --git a/crates/core/src/kernel/models/fields.rs b/crates/core/src/kernel/models/fields.rs index fa672aaefc..6c699f0e88 100644 --- a/crates/core/src/kernel/models/fields.rs +++ b/crates/core/src/kernel/models/fields.rs @@ -1,8 +1,8 @@ //! Schema definitions for action types +use delta_kernel::schema::{ArrayType, DataType, MapType, StructField, StructType}; use lazy_static::lazy_static; -use super::schema::{ArrayType, DataType, MapType, StructField, StructType}; use super::ActionType; impl ActionType { diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 161de0352a..3a88564f1d 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -1,93 +1,21 @@ //! Delta table schema -use std::borrow::Borrow; -use std::fmt::Formatter; -use std::hash::{Hash, Hasher}; use std::sync::Arc; -use std::{collections::HashMap, fmt::Display}; -use serde::{Deserialize, Serialize}; +pub use delta_kernel::schema::{ + ArrayType, ColumnMetadataKey, DataType, MapType, MetadataValue, PrimitiveType, StructField, + StructType, +}; use serde_json::Value; use crate::kernel::error::Error; use crate::kernel::DataCheck; -use crate::protocol::ProtocolError; /// Type alias for a top level schema pub type Schema = StructType; /// Schema reference type pub type SchemaRef = Arc; -/// A value that can be stored in the metadata of a Delta table schema entity. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -#[serde(untagged)] -pub enum MetadataValue { - /// A number value - Number(i32), - /// A string value - String(String), - /// A Boolean value - Boolean(bool), -} - -impl From for MetadataValue { - fn from(value: String) -> Self { - Self::String(value) - } -} - -impl From<&String> for MetadataValue { - fn from(value: &String) -> Self { - Self::String(value.clone()) - } -} - -impl From for MetadataValue { - fn from(value: i32) -> Self { - Self::Number(value) - } -} - -impl From for MetadataValue { - fn from(value: bool) -> Self { - Self::Boolean(value) - } -} - -impl From for MetadataValue { - fn from(value: Value) -> Self { - Self::String(value.to_string()) - } -} - -#[derive(Debug)] -#[allow(missing_docs)] -pub enum ColumnMetadataKey { - ColumnMappingId, - ColumnMappingPhysicalName, - GenerationExpression, - IdentityStart, - IdentityStep, - IdentityHighWaterMark, - IdentityAllowExplicitInsert, - Invariants, -} - -impl AsRef for ColumnMetadataKey { - fn as_ref(&self) -> &str { - match self { - Self::ColumnMappingId => "delta.columnMapping.id", - Self::ColumnMappingPhysicalName => "delta.columnMapping.physicalName", - Self::GenerationExpression => "delta.generationExpression", - Self::IdentityAllowExplicitInsert => "delta.identity.allowExplicitInsert", - Self::IdentityHighWaterMark => "delta.identity.highWaterMark", - Self::IdentityStart => "delta.identity.start", - Self::IdentityStep => "delta.identity.step", - Self::Invariants => "delta.invariants", - } - } -} - /// An invariant for a column that is enforced on all writes to a Delta table. #[derive(Eq, PartialEq, Debug, Default, Clone)] pub struct Invariant { @@ -117,168 +45,17 @@ impl DataCheck for Invariant { } } -/// Represents a struct field defined in the Delta table schema. 
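Since `models/schema.rs` now re-exports the kernel schema types (the `pub use delta_kernel::schema::{...}` above), downstream imports keep resolving unchanged. A minimal sketch, assuming the crate is consumed as `deltalake_core` (crate path is an assumption here):

use deltalake_core::kernel::{DataType, PrimitiveType, StructField, StructType};

fn example_schema() -> StructType {
    // Construction reads exactly as before the migration; only the backing
    // definitions moved into delta_kernel.
    StructType::new(vec![StructField::new(
        "id",
        DataType::Primitive(PrimitiveType::Long),
        true,
    )])
}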
-// https://github.com/delta-io/delta/blob/master/PROTOCOL.md#Schema-Serialization-Format -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct StructField { - /// Name of this (possibly nested) column - pub name: String, - /// The data type of this field - #[serde(rename = "type")] - pub data_type: DataType, - /// Denotes whether this Field can be null - pub nullable: bool, - /// A JSON map containing information about this column - pub metadata: HashMap, -} - -impl Hash for StructField { - fn hash(&self, state: &mut H) { - self.name.hash(state); - self.data_type.hash(state); - self.nullable.hash(state); - } -} - -impl Borrow for StructField { - fn borrow(&self) -> &str { - self.name.as_ref() - } -} - -impl Eq for StructField {} - -impl StructField { - /// Creates a new field - pub fn new(name: impl Into, data_type: impl Into, nullable: bool) -> Self { - Self { - name: name.into(), - data_type: data_type.into(), - nullable, - metadata: HashMap::default(), - } - } - - /// Creates a new field with metadata - pub fn with_metadata( - mut self, - metadata: impl IntoIterator, impl Into)>, - ) -> Self { - self.metadata = metadata - .into_iter() - .map(|(k, v)| (k.into(), v.into())) - .collect(); - self - } - - /// Get the value of a specific metadata key - pub fn get_config_value(&self, key: &ColumnMetadataKey) -> Option<&MetadataValue> { - self.metadata.get(key.as_ref()) - } - - #[inline] - /// Returns the name of the column - pub fn name(&self) -> &String { - &self.name - } - - #[inline] - /// Returns whether the column is nullable - pub fn is_nullable(&self) -> bool { - self.nullable - } - - /// Returns the physical name of the column - /// Equals the name if column mapping is not enabled on table - pub fn physical_name(&self) -> Result<&str, Error> { - // Even on mapping type id the physical name should be there for partitions - let phys_name = self.get_config_value(&ColumnMetadataKey::ColumnMappingPhysicalName); - match phys_name { - None => Ok(&self.name), - Some(MetadataValue::Boolean(_)) => Ok(&self.name), - Some(MetadataValue::String(s)) => Ok(s), - Some(MetadataValue::Number(_)) => Err(Error::MetadataError( - "Unexpected type for physical name".to_string(), - )), - } - } - - #[inline] - /// Returns the data type of the column - pub const fn data_type(&self) -> &DataType { - &self.data_type - } - - #[inline] - /// Returns the metadata of the column - pub const fn metadata(&self) -> &HashMap { - &self.metadata - } -} - -/// A struct is used to represent both the top-level schema of the table -/// as well as struct columns that contain nested columns. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] -pub struct StructType { - #[serde(rename = "type")] - /// The type of this struct - pub type_name: String, - /// The type of element stored in this array - pub fields: Vec, +/// Trait to add convenince functions to struct type +pub trait StructTypeExt { + /// Get all invariants in the schemas + fn get_invariants(&self) -> Result, Error>; } -impl StructType { - /// Creates a new struct type - pub fn new(fields: Vec) -> Self { - Self { - type_name: "struct".into(), - fields, - } - } - - /// Returns an immutable reference of the fields in the struct - pub fn fields(&self) -> &Vec { - &self.fields - } - - /// Find the index of the column with the given name. 
- pub fn index_of(&self, name: &str) -> Result { - let (idx, _) = self - .fields() - .iter() - .enumerate() - .find(|(_, b)| b.name() == name) - .ok_or_else(|| { - let valid_fields: Vec<_> = self.fields.iter().map(|f| f.name()).collect(); - Error::Schema(format!( - "Unable to get field named \"{name}\". Valid fields: {valid_fields:?}" - )) - })?; - Ok(idx) - } - - /// Returns a reference of a specific [`StructField`] instance selected by name. - pub fn field_with_name(&self, name: &str) -> Result<&StructField, Error> { - match name.split_once('.') { - Some((parent, children)) => { - let parent_field = &self.fields[self.index_of(parent)?]; - match parent_field.data_type { - DataType::Struct(ref inner) => Ok(inner.field_with_name(children)?), - _ => Err(Error::Schema(format!( - "Field {} is not a struct type", - parent_field.name() - ))), - } - } - None => Ok(&self.fields[self.index_of(name)?]), - } - } - +impl StructTypeExt for StructType { /// Get all invariants in the schemas - pub fn get_invariants(&self) -> Result, Error> { + fn get_invariants(&self) -> Result, Error> { let mut remaining_fields: Vec<(String, StructField)> = self .fields() - .iter() .map(|field| (field.name.clone(), field.clone())) .collect(); let mut invariants: Vec = Vec::new(); @@ -297,7 +74,6 @@ impl StructType { remaining_fields.extend( inner .fields() - .iter() .map(|field| { let new_prefix = add_segment(&field_path, &field.name); (new_prefix, field.clone()) @@ -349,521 +125,11 @@ impl StructType { } } -impl FromIterator for StructType { - fn from_iter>(iter: T) -> Self { - Self { - type_name: "struct".into(), - fields: iter.into_iter().collect(), - } - } -} - -impl<'a> FromIterator<&'a StructField> for StructType { - fn from_iter>(iter: T) -> Self { - Self { - type_name: "struct".into(), - fields: iter.into_iter().cloned().collect(), - } - } -} - -impl From<[StructField; N]> for StructType { - fn from(value: [StructField; N]) -> Self { - Self { - type_name: "struct".into(), - fields: value.to_vec(), - } - } -} - -impl<'a, const N: usize> From<[&'a StructField; N]> for StructType { - fn from(value: [&'a StructField; N]) -> Self { - Self { - type_name: "struct".into(), - fields: value.into_iter().cloned().collect(), - } - } -} - -impl<'a> IntoIterator for &'a StructType { - type Item = &'a StructField; - type IntoIter = std::slice::Iter<'a, StructField>; - - fn into_iter(self) -> Self::IntoIter { - self.fields.iter() - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] -#[serde(rename_all = "camelCase")] -/// An array stores a variable length collection of items of some type. 
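With `get_invariants` moved off `StructType` into the `StructTypeExt` extension trait above, callers now have to bring the trait into scope. A short sketch (crate path `deltalake_core` assumed):

use deltalake_core::kernel::{Schema, StructTypeExt};

fn invariant_count(schema: &Schema) -> usize {
    // `get_invariants` now resolves through the extension trait rather than
    // an inherent method on StructType.
    schema.get_invariants().map(|inv| inv.len()).unwrap_or(0)
}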
-pub struct ArrayType { - #[serde(rename = "type")] - /// The type of this struct - pub type_name: String, - /// The type of element stored in this array - pub element_type: DataType, - /// Denoting whether this array can contain one or more null values - pub contains_null: bool, -} - -impl ArrayType { - /// Creates a new array type - pub fn new(element_type: DataType, contains_null: bool) -> Self { - Self { - type_name: "array".into(), - element_type, - contains_null, - } - } - - #[inline] - /// Returns the element type of the array - pub const fn element_type(&self) -> &DataType { - &self.element_type - } - - #[inline] - /// Returns whether the array can contain null values - pub const fn contains_null(&self) -> bool { - self.contains_null - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] -#[serde(rename_all = "camelCase")] -/// A map stores an arbitrary length collection of key-value pairs -pub struct MapType { - #[serde(rename = "type")] - /// The type of this struct - pub type_name: String, - /// The type of element used for the key of this map - pub key_type: DataType, - /// The type of element used for the value of this map - pub value_type: DataType, - /// Denoting whether this array can contain one or more null values - #[serde(default = "default_true")] - pub value_contains_null: bool, -} - -impl MapType { - /// Creates a new map type - pub fn new(key_type: DataType, value_type: DataType, value_contains_null: bool) -> Self { - Self { - type_name: "map".into(), - key_type, - value_type, - value_contains_null, - } - } - - #[inline] - /// Returns the key type of the map - pub const fn key_type(&self) -> &DataType { - &self.key_type - } - - #[inline] - /// Returns the value type of the map - pub const fn value_type(&self) -> &DataType { - &self.value_type - } - - #[inline] - /// Returns whether the map can contain null values - pub const fn value_contains_null(&self) -> bool { - self.value_contains_null - } -} - -fn default_true() -> bool { - true -} - -/// The maximum precision for [PrimitiveType::Decimal] values -pub const DECIMAL_MAX_PRECISION: u8 = 38; - -/// The maximum scale for [PrimitiveType::Decimal] values -pub const DECIMAL_MAX_SCALE: i8 = 38; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Copy, Clone, Eq, Hash)] -#[serde(rename_all = "snake_case")] -/// Primitive types supported by Delta -pub enum PrimitiveType { - /// UTF-8 encoded string of characters - String, - /// i64: 8-byte signed integer. Range: -9223372036854775808 to 9223372036854775807 - Long, - /// i32: 4-byte signed integer. Range: -2147483648 to 2147483647 - Integer, - /// i16: 2-byte signed integer numbers. Range: -32768 to 32767 - Short, - /// i8: 1-byte signed integer number. Range: -128 to 127 - Byte, - /// f32: 4-byte single-precision floating-point numbers - Float, - /// f64: 8-byte double-precision floating-point numbers - Double, - /// bool: boolean values - Boolean, - /// Binary: uninterpreted binary data - Binary, - /// Date: Calendar date (year, month, day) - Date, - /// Microsecond precision timestamp, adjusted to UTC. 
- Timestamp, - /// Micrsoecond precision timestamp with no timezone - #[serde(alias = "timestampNtz")] - TimestampNtz, - #[serde( - serialize_with = "serialize_decimal", - deserialize_with = "deserialize_decimal", - untagged - )] - /// Decimal: arbitrary precision decimal numbers - Decimal(u8, i8), -} - -fn serialize_decimal( - precision: &u8, - scale: &i8, - serializer: S, -) -> Result { - serializer.serialize_str(&format!("decimal({},{})", precision, scale)) -} - -fn deserialize_decimal<'de, D>(deserializer: D) -> Result<(u8, i8), D::Error> -where - D: serde::Deserializer<'de>, -{ - let str_value = String::deserialize(deserializer)?; - if !str_value.starts_with("decimal(") || !str_value.ends_with(')') { - return Err(serde::de::Error::custom(format!( - "Invalid decimal: {}", - str_value - ))); - } - - let mut parts = str_value[8..str_value.len() - 1].split(','); - let precision = parts - .next() - .and_then(|part| part.trim().parse::().ok()) - .ok_or_else(|| { - serde::de::Error::custom(format!("Invalid precision in decimal: {}", str_value)) - })?; - let scale = parts - .next() - .and_then(|part| part.trim().parse::().ok()) - .ok_or_else(|| { - serde::de::Error::custom(format!("Invalid scale in decimal: {}", str_value)) - })?; - if precision > DECIMAL_MAX_PRECISION || scale > DECIMAL_MAX_SCALE { - return Err(serde::de::Error::custom(format!( - "Precision or scale is larger than 38: {}, {}", - precision, scale - ))); - } - Ok((precision, scale)) -} - -impl Display for PrimitiveType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - PrimitiveType::String => write!(f, "string"), - PrimitiveType::Long => write!(f, "long"), - PrimitiveType::Integer => write!(f, "integer"), - PrimitiveType::Short => write!(f, "short"), - PrimitiveType::Byte => write!(f, "byte"), - PrimitiveType::Float => write!(f, "float"), - PrimitiveType::Double => write!(f, "double"), - PrimitiveType::Boolean => write!(f, "boolean"), - PrimitiveType::Binary => write!(f, "binary"), - PrimitiveType::Date => write!(f, "date"), - PrimitiveType::Timestamp => write!(f, "timestamp"), - PrimitiveType::TimestampNtz => write!(f, "timestampNtz"), - PrimitiveType::Decimal(precision, scale) => { - write!(f, "decimal({},{})", precision, scale) - } - } - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] -#[serde(untagged, rename_all = "camelCase")] -/// Top level delta tdatatypes -pub enum DataType { - /// UTF-8 encoded string of characters - Primitive(PrimitiveType), - /// An array stores a variable length collection of items of some type. - Array(Box), - /// A struct is used to represent both the top-level schema of the table as well - /// as struct columns that contain nested columns. 
- Struct(Box), - /// A map stores an arbitrary length collection of key-value pairs - /// with a single keyType and a single valueType - Map(Box), -} - -impl From for DataType { - fn from(map_type: MapType) -> Self { - DataType::Map(Box::new(map_type)) - } -} - -impl From for DataType { - fn from(struct_type: StructType) -> Self { - DataType::Struct(Box::new(struct_type)) - } -} - -impl From for DataType { - fn from(array_type: ArrayType) -> Self { - DataType::Array(Box::new(array_type)) - } -} - -#[allow(missing_docs)] -impl DataType { - pub const STRING: Self = DataType::Primitive(PrimitiveType::String); - pub const LONG: Self = DataType::Primitive(PrimitiveType::Long); - pub const INTEGER: Self = DataType::Primitive(PrimitiveType::Integer); - pub const SHORT: Self = DataType::Primitive(PrimitiveType::Short); - pub const BYTE: Self = DataType::Primitive(PrimitiveType::Byte); - pub const FLOAT: Self = DataType::Primitive(PrimitiveType::Float); - pub const DOUBLE: Self = DataType::Primitive(PrimitiveType::Double); - pub const BOOLEAN: Self = DataType::Primitive(PrimitiveType::Boolean); - pub const BINARY: Self = DataType::Primitive(PrimitiveType::Binary); - pub const DATE: Self = DataType::Primitive(PrimitiveType::Date); - pub const TIMESTAMP: Self = DataType::Primitive(PrimitiveType::Timestamp); - pub const TIMESTAMPNTZ: Self = DataType::Primitive(PrimitiveType::TimestampNtz); - - pub fn decimal(precision: u8, scale: i8) -> Result { - if precision > DECIMAL_MAX_PRECISION || scale > DECIMAL_MAX_SCALE { - return Err(ProtocolError::InvalidField(format!( - "decimal({},{})", - precision, scale - ))); - } - Ok(DataType::Primitive(PrimitiveType::Decimal( - precision, scale, - ))) - } - - pub fn struct_type(fields: Vec) -> Self { - DataType::Struct(Box::new(StructType::new(fields))) - } -} - -impl Display for DataType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - DataType::Primitive(p) => write!(f, "{}", p), - DataType::Array(a) => write!(f, "array<{}>", a.element_type), - DataType::Struct(s) => { - write!(f, "struct<")?; - for (i, field) in s.fields.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}: {}", field.name, field.data_type)?; - } - write!(f, ">") - } - DataType::Map(m) => write!(f, "map<{}, {}>", m.key_type, m.value_type), - } - } -} - #[cfg(test)] mod tests { use super::*; use serde_json; use serde_json::json; - use std::collections::hash_map::DefaultHasher; - - #[test] - fn test_serde_data_types() { - let data = r#" - { - "name": "a", - "type": "integer", - "nullable": false, - "metadata": {} - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - assert!(matches!( - field.data_type, - DataType::Primitive(PrimitiveType::Integer) - )); - - let data = r#" - { - "name": "c", - "type": { - "type": "array", - "elementType": "integer", - "containsNull": false - }, - "nullable": true, - "metadata": {} - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - assert!(matches!(field.data_type, DataType::Array(_))); - - let data = r#" - { - "name": "e", - "type": { - "type": "array", - "elementType": { - "type": "struct", - "fields": [ - { - "name": "d", - "type": "integer", - "nullable": false, - "metadata": {} - } - ] - }, - "containsNull": true - }, - "nullable": true, - "metadata": {} - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - assert!(matches!(field.data_type, DataType::Array(_))); - match field.data_type { - DataType::Array(array) => 
assert!(matches!(array.element_type, DataType::Struct(_))), - _ => unreachable!(), - } - - let data = r#" - { - "name": "f", - "type": { - "type": "map", - "keyType": "string", - "valueType": "string", - "valueContainsNull": true - }, - "nullable": true, - "metadata": {} - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - assert!(matches!(field.data_type, DataType::Map(_))); - } - - #[test] - fn test_roundtrip_decimal() { - let data = r#" - { - "name": "a", - "type": "decimal(10, 2)", - "nullable": false, - "metadata": {} - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - assert!(matches!( - field.data_type, - DataType::Primitive(PrimitiveType::Decimal(10, 2)) - )); - - let json_str = serde_json::to_string(&field).unwrap(); - assert_eq!( - json_str, - r#"{"name":"a","type":"decimal(10,2)","nullable":false,"metadata":{}}"# - ); - } - - #[test] - fn test_invalid_decimal() { - let data = r#" - { - "name": "a", - "type": "decimal(39, 10)", - "nullable": false, - "metadata": {} - } - "#; - assert!(matches!( - serde_json::from_str::(data).unwrap_err(), - _ - )); - - let data = r#" - { - "name": "a", - "type": "decimal(10, 39)", - "nullable": false, - "metadata": {} - } - "#; - assert!(matches!( - serde_json::from_str::(data).unwrap_err(), - _ - )); - } - - #[test] - fn test_field_metadata() { - let data = r#" - { - "name": "e", - "type": { - "type": "array", - "elementType": { - "type": "struct", - "fields": [ - { - "name": "d", - "type": "integer", - "nullable": false, - "metadata": { - "delta.columnMapping.id": 5, - "delta.columnMapping.physicalName": "col-a7f4159c-53be-4cb0-b81a-f7e5240cfc49" - } - } - ] - }, - "containsNull": true - }, - "nullable": true, - "metadata": { - "delta.columnMapping.id": 4, - "delta.columnMapping.physicalName": "col-5f422f40-de70-45b2-88ab-1d5c90e94db1" - } - } - "#; - let field: StructField = serde_json::from_str(data).unwrap(); - - let col_id = field - .get_config_value(&ColumnMetadataKey::ColumnMappingId) - .unwrap(); - assert!(matches!(col_id, MetadataValue::Number(num) if *num == 4)); - let physical_name = field - .get_config_value(&ColumnMetadataKey::ColumnMappingPhysicalName) - .unwrap(); - assert!( - matches!(physical_name, MetadataValue::String(name) if *name == "col-5f422f40-de70-45b2-88ab-1d5c90e94db1") - ); - } - - #[test] - fn test_read_schemas() { - let file = std::fs::File::open("./tests/serde/schema.json").unwrap(); - let schema: Result = serde_json::from_reader(file); - assert!(schema.is_ok()); - - let file = std::fs::File::open("./tests/serde/checkpoint_schema.json").unwrap(); - let schema: Result = serde_json::from_reader(file); - assert!(schema.is_ok()) - } #[test] fn test_get_invariants() { @@ -934,88 +200,4 @@ mod tests { let buf = r#"{"type":"struct","fields":[{"name":"ID_D_DATE","type":"long","nullable":true,"metadata":{"delta.identity.start":1,"delta.identity.step":1,"delta.identity.allowExplicitInsert":false}},{"name":"TXT_DateKey","type":"string","nullable":true,"metadata":{}}]}"#; let _schema: StructType = serde_json::from_str(buf).expect("Failed to load"); } - - fn get_hash(field: &StructField) -> u64 { - let mut hasher = DefaultHasher::new(); - field.hash(&mut hasher); - hasher.finish() - } - - #[test] - fn test_hash_struct_field() { - // different names should result in different hashes - let field_1 = StructField::new( - "field_name_1", - DataType::Primitive(PrimitiveType::Decimal(4, 4)), - true, - ); - let field_2 = StructField::new( - "field_name_2", - 
DataType::Primitive(PrimitiveType::Decimal(4, 4)), - true, - ); - assert_ne!(get_hash(&field_1), get_hash(&field_2)); - - // different types should result in different hashes - let field_int = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::Integer), - true, - ); - let field_string = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::String), - true, - ); - assert_ne!(get_hash(&field_int), get_hash(&field_string)); - - // different nullability should result in different hashes - let field_true = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::Binary), - true, - ); - let field_false = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::Binary), - false, - ); - assert_ne!(get_hash(&field_true), get_hash(&field_false)); - - // case where hashes are the same - let field_1 = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::Timestamp), - true, - ); - let field_2 = StructField::new( - "field_name", - DataType::Primitive(PrimitiveType::Timestamp), - true, - ); - assert_eq!(get_hash(&field_1), get_hash(&field_2)); - } - - #[test] - fn test_field_with_name() { - let schema = StructType::new(vec![ - StructField::new("a", DataType::STRING, true), - StructField::new("b", DataType::INTEGER, true), - ]); - let field = schema.field_with_name("b").unwrap(); - assert_eq!(*field, StructField::new("b", DataType::INTEGER, true)); - } - - #[test] - fn test_field_with_name_nested() { - let nested = StructType::new(vec![StructField::new("a", DataType::BOOLEAN, true)]); - let schema = StructType::new(vec![ - StructField::new("a", DataType::STRING, true), - StructField::new("b", DataType::Struct(Box::new(nested)), true), - ]); - - let field = schema.field_with_name("b.a").unwrap(); - - assert_eq!(*field, StructField::new("a", DataType::BOOLEAN, true)); - } } diff --git a/crates/core/src/kernel/scalars.rs b/crates/core/src/kernel/scalars.rs new file mode 100644 index 0000000000..c596bd9e10 --- /dev/null +++ b/crates/core/src/kernel/scalars.rs @@ -0,0 +1,233 @@ +//! Auxiliary methods for dealing with kernel scalars +//! +use std::cmp::Ordering; + +use arrow_array::Array; +use arrow_schema::TimeUnit; +use chrono::{DateTime, TimeZone, Utc}; +use delta_kernel::{ + expressions::{Scalar, StructData}, + schema::StructField, +}; +use object_store::path::Path; + +use crate::NULL_PARTITION_VALUE_DATA_PATH; + +/// Auxiliary methods for dealing with kernel scalars +pub trait ScalarExt: Sized { + /// Serialize to string + fn serialize(&self) -> String; + /// Serialize to string for use in hive partition file names + fn serialize_encoded(&self) -> String; + /// Create a [`Scalar`] from an arrow array row + fn from_array(arr: &dyn Array, index: usize) -> Option; +} + +impl ScalarExt for Scalar { + /// Serializes this scalar as a string. 
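+    ///
+    /// Decimals are rescaled for display: e.g. `Decimal(123456789, 9, 2)`
+    /// serializes as `"1234567.89"`. Timestamps are formatted as
+    /// `%Y-%m-%d %H:%M:%S%.6f`.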
+ fn serialize(&self) -> String { + match self { + Self::String(s) => s.to_owned(), + Self::Byte(b) => b.to_string(), + Self::Short(s) => s.to_string(), + Self::Integer(i) => i.to_string(), + Self::Long(l) => l.to_string(), + Self::Float(f) => f.to_string(), + Self::Double(d) => d.to_string(), + Self::Boolean(b) => if *b { "true" } else { "false" }.to_string(), + Self::TimestampNtz(ts) | Self::Timestamp(ts) => { + let ts = Utc.timestamp_micros(*ts).single().unwrap(); + ts.format("%Y-%m-%d %H:%M:%S%.6f").to_string() + } + Self::Date(days) => { + let date = DateTime::from_timestamp(*days as i64 * 24 * 3600, 0).unwrap(); + date.format("%Y-%m-%d").to_string() + } + Self::Decimal(value, _, scale) => match scale.cmp(&0) { + Ordering::Equal => value.to_string(), + Ordering::Greater => { + let scalar_multiple = 10_i128.pow(*scale as u32); + let mut s = String::new(); + s.push_str((value / scalar_multiple).to_string().as_str()); + s.push('.'); + s.push_str(&format!( + "{:0>scale$}", + value % scalar_multiple, + scale = *scale as usize + )); + s + } + Ordering::Less => { + let mut s = value.to_string(); + for _ in 0..*scale { + s.push('0'); + } + s + } + }, + Self::Binary(val) => create_escaped_binary_string(val.as_slice()), + Self::Null(_) => "null".to_string(), + Self::Struct(_) => unimplemented!(), + } + } + + /// Serializes this scalar as a string for use in hive partition file names. + fn serialize_encoded(&self) -> String { + if self.is_null() { + return NULL_PARTITION_VALUE_DATA_PATH.to_string(); + } + Path::from(self.serialize()).to_string() + } + + /// Create a [`Scalar`] form a row in an arrow array. + fn from_array(arr: &dyn Array, index: usize) -> Option { + use arrow_array::*; + use arrow_schema::DataType::*; + + if arr.len() <= index { + return None; + } + if arr.is_null(index) { + return Some(Self::Null(arr.data_type().try_into().ok()?)); + } + + match arr.data_type() { + Utf8 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::String(v.value(index).to_string())), + LargeUtf8 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::String(v.value(index).to_string())), + Boolean => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Boolean(v.value(index))), + Binary => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Binary(v.value(index).to_vec())), + LargeBinary => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Binary(v.value(index).to_vec())), + FixedSizeBinary(_) => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Binary(v.value(index).to_vec())), + Int8 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Byte(v.value(index))), + Int16 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Short(v.value(index))), + Int32 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Integer(v.value(index))), + Int64 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Long(v.value(index))), + UInt8 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Byte(v.value(index) as i8)), + UInt16 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Short(v.value(index) as i16)), + UInt32 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Integer(v.value(index) as i32)), + UInt64 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Long(v.value(index) as i64)), + Float32 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Float(v.value(index))), + Float64 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Double(v.value(index))), + Decimal128(precision, scale) => { + arr.as_any().downcast_ref::().map(|v| { + let value = v.value(index); + 
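+                    // Arrow's `Decimal128` scale is a signed `i8`; the cast
+                    // below assumes it is non-negative, since the kernel
+                    // scalar carries an unsigned (`u8`) scale.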
Self::Decimal(value, *precision, *scale as u8) + }) + } + Date32 => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Date(v.value(index))), + Timestamp(TimeUnit::Microsecond, None) => arr + .as_any() + .downcast_ref::() + .map(|v| Self::TimestampNtz(v.value(index))), + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.eq_ignore_ascii_case("utc") => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Timestamp(v.clone().value(index))), + Struct(fields) => { + let struct_fields = fields + .iter() + .flat_map(|f| TryFrom::try_from(f.as_ref())) + .collect::>(); + let values = arr + .as_any() + .downcast_ref::() + .and_then(|struct_arr| { + struct_fields + .iter() + .map(|f: &StructField| { + struct_arr + .column_by_name(f.name()) + .and_then(|c| Self::from_array(c.as_ref(), index)) + }) + .collect::>>() + })?; + Some(Self::Struct( + StructData::try_new(struct_fields, values).ok()?, + )) + } + Float16 + | Decimal256(_, _) + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Map(_, _) + | Date64 + | Timestamp(_, _) + | Time32(_) + | Time64(_) + | Duration(_) + | Interval(_) + | Dictionary(_, _) + | RunEndEncoded(_, _) + | Union(_, _) + | Utf8View + | BinaryView + | ListView(_) + | LargeListView(_) + | Null => None, + } + } +} + +fn create_escaped_binary_string(data: &[u8]) -> String { + let mut escaped_string = String::new(); + for &byte in data { + // Convert each byte to its two-digit hexadecimal representation + let hex_representation = format!("{:04X}", byte); + // Append the hexadecimal representation with an escape sequence + escaped_string.push_str("\\u"); + escaped_string.push_str(&hex_representation); + } + escaped_string +} diff --git a/crates/core/src/kernel/snapshot/log_data.rs b/crates/core/src/kernel/snapshot/log_data.rs index 24fae0ad75..254616691c 100644 --- a/crates/core/src/kernel/snapshot/log_data.rs +++ b/crates/core/src/kernel/snapshot/log_data.rs @@ -4,14 +4,16 @@ use std::sync::Arc; use arrow_array::{Array, Int32Array, Int64Array, MapArray, RecordBatch, StringArray, StructArray}; use chrono::{DateTime, Utc}; +use delta_kernel::expressions::Scalar; use indexmap::IndexMap; use object_store::path::Path; use object_store::ObjectMeta; use percent_encoding::percent_decode_str; +use super::super::scalars::ScalarExt; use crate::kernel::arrow::extract::{extract_and_cast, extract_and_cast_opt}; use crate::kernel::{ - DataType, DeletionVectorDescriptor, Metadata, Remove, Scalar, StructField, StructType, + DataType, DeletionVectorDescriptor, Metadata, Remove, StructField, StructType, }; use crate::{DeltaResult, DeltaTableError}; @@ -351,7 +353,16 @@ impl<'a> FileStatsAccessor<'a> { metadata .partition_columns .iter() - .map(|c| Ok((c.as_str(), schema.field_with_name(c.as_str())?))) + .map(|c| { + Ok(( + c.as_str(), + schema + .field(c.as_str()) + .ok_or(DeltaTableError::PartitionError { + partition: c.clone(), + })?, + )) + }) .collect::>>()?, ); let deletion_vector = extract_and_cast_opt::(data, "add.deletionVector"); @@ -670,7 +681,6 @@ mod datafusion { let column_statistics = self .schema .fields() - .iter() .map(|f| self.column_stats(f.name())) .collect::>>()?; Some(Statistics { diff --git a/crates/core/src/kernel/snapshot/mod.rs b/crates/core/src/kernel/snapshot/mod.rs index cd6cf8bb5f..d34b78fbed 100644 --- a/crates/core/src/kernel/snapshot/mod.rs +++ b/crates/core/src/kernel/snapshot/mod.rs @@ -315,8 +315,8 @@ impl Snapshot { let stats_fields = if let Some(stats_cols) = self.table_config().stats_columns() { stats_cols .iter() - .map(|col| match 
schema.field_with_name(col) { - Ok(field) => match field.data_type() { + .map(|col| match schema.field(col) { + Some(field) => match field.data_type() { DataType::Map(_) | DataType::Array(_) | &DataType::BINARY => { Err(DeltaTableError::Generic(format!( "Stats column {} has unsupported type {}", @@ -340,7 +340,7 @@ impl Snapshot { let num_indexed_cols = self.table_config().num_indexed_cols(); schema .fields - .iter() + .values() .enumerate() .filter_map(|(idx, f)| stats_field(idx, num_indexed_cols, f)) .collect() @@ -699,7 +699,6 @@ fn stats_field(idx: usize, num_indexed_cols: i32, field: &StructField) -> Option StructType::new( dt_struct .fields() - .iter() .flat_map(|f| stats_field(idx, num_indexed_cols, f)) .collect(), ), @@ -718,12 +717,7 @@ fn to_count_field(field: &StructField) -> Option { DataType::Map(_) | DataType::Array(_) | &DataType::BINARY => None, DataType::Struct(s) => Some(StructField::new( field.name(), - StructType::new( - s.fields() - .iter() - .filter_map(to_count_field) - .collect::>(), - ), + StructType::new(s.fields().filter_map(to_count_field).collect::>()), true, )), _ => Some(StructField::new(field.name(), DataType::LONG, true)), diff --git a/crates/core/src/operations/cast.rs b/crates/core/src/operations/cast.rs index b231346266..e0dd9f5639 100644 --- a/crates/core/src/operations/cast.rs +++ b/crates/core/src/operations/cast.rs @@ -34,13 +34,12 @@ pub(crate) fn merge_struct( left: &StructType, right: &StructType, ) -> Result { - let mut errors = Vec::with_capacity(left.fields().len()); + let mut errors = Vec::new(); let merged_fields: Result, ArrowError> = left .fields() - .iter() .map(|field| { - let right_field = right.field_with_name(field.name()); - if let Ok(right_field) = right_field { + let right_field = right.field(field.name()); + if let Some(right_field) = right_field { let type_or_not = merge_type(field.data_type(), right_field.data_type()); match type_or_not { Err(e) => { @@ -67,7 +66,7 @@ pub(crate) fn merge_struct( match merged_fields { Ok(mut fields) => { for field in right.fields() { - if !left.field_with_name(field.name()).is_ok() { + if !left.field(field.name()).is_some() { fields.push(field.clone()); } } @@ -200,18 +199,21 @@ pub fn cast_record_batch( #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use arrow::array::ArrayData; + use arrow_array::{Array, ArrayRef, ListArray, RecordBatch}; + use arrow_buffer::Buffer; + use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; + use itertools::Itertools; + use crate::kernel::{ ArrayType as DeltaArrayType, DataType as DeltaDataType, StructField as DeltaStructField, StructType as DeltaStructType, }; use crate::operations::cast::MetadataValue; use crate::operations::cast::{cast_record_batch, is_cast_required}; - use arrow::array::ArrayData; - use arrow_array::{Array, ArrayRef, ListArray, RecordBatch}; - use arrow_buffer::Buffer; - use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; - use std::collections::HashMap; - use std::sync::Arc; #[test] fn test_merge_schema_with_dict() { @@ -253,13 +255,14 @@ mod tests { .with_metadata(right_meta)]); let result = super::merge_struct(&left_schema, &right_schema).unwrap(); - assert_eq!(result.fields().len(), 1); - let delta_type = result.fields()[0].data_type(); + let fields = result.fields().collect_vec(); + assert_eq!(fields.len(), 1); + let delta_type = fields[0].data_type(); assert_eq!(delta_type, &DeltaDataType::STRING); let mut expected_meta = HashMap::new(); 
expected_meta.insert("a".to_string(), MetadataValue::String("a1".to_string())); expected_meta.insert("b".to_string(), MetadataValue::String("b2".to_string())); - assert_eq!(result.fields()[0].metadata(), &expected_meta); + assert_eq!(fields[0].metadata(), &expected_meta); } #[test] diff --git a/crates/core/src/operations/convert_to_delta.rs b/crates/core/src/operations/convert_to_delta.rs index 2e157c38c0..a51d353b20 100644 --- a/crates/core/src/operations/convert_to_delta.rs +++ b/crates/core/src/operations/convert_to_delta.rs @@ -1,9 +1,24 @@ //! Command for converting a Parquet table to a Delta table in place // https://github.com/delta-io/delta/blob/1d5dd774111395b0c4dc1a69c94abc169b1c83b6/spark/src/main/scala/org/apache/spark/sql/delta/commands/ConvertToDeltaCommand.scala +use std::collections::{HashMap, HashSet}; +use std::num::TryFromIntError; +use std::str::{FromStr, Utf8Error}; +use std::sync::Arc; + +use arrow::{datatypes::Schema as ArrowSchema, error::ArrowError}; +use futures::future::{self, BoxFuture}; +use futures::TryStreamExt; +use indexmap::IndexMap; +use itertools::Itertools; +use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder}; +use parquet::errors::ParquetError; +use percent_encoding::percent_decode_str; +use serde_json::{Map, Value}; +use tracing::debug; use crate::operations::get_num_idx_cols_and_stats_columns; use crate::{ - kernel::{Add, DataType, Schema, StructField}, + kernel::{scalars::ScalarExt, Add, DataType, Schema, StructField}, logstore::{LogStore, LogStoreRef}, operations::create::CreateBuilder, protocol::SaveMode, @@ -12,25 +27,6 @@ use crate::{ writer::stats::stats_from_parquet_metadata, DeltaResult, DeltaTable, DeltaTableError, ObjectStoreError, NULL_PARTITION_VALUE_DATA_PATH, }; -use arrow::{datatypes::Schema as ArrowSchema, error::ArrowError}; -use futures::{ - future::{self, BoxFuture}, - TryStreamExt, -}; -use indexmap::IndexMap; -use parquet::{ - arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder}, - errors::ParquetError, -}; -use percent_encoding::percent_decode_str; -use serde_json::{Map, Value}; -use std::{ - collections::{HashMap, HashSet}, - num::TryFromIntError, - str::{FromStr, Utf8Error}, - sync::Arc, -}; -use tracing::debug; /// Error converting a Parquet table to a Delta table #[derive(Debug, thiserror::Error)] @@ -52,7 +48,7 @@ enum Error { #[error("The schema of partition columns must be provided to convert a Parquet table to a Delta table")] MissingPartitionSchema, #[error("Partition column provided by the user does not exist in the parquet files")] - PartitionColumnNotExist(HashSet), + PartitionColumnNotExist, #[error("The given location is already a delta table location")] DeltaTableAlready, #[error("Location must be provided to convert a Parquet table to a Delta table")] @@ -104,7 +100,7 @@ pub struct ConvertToDeltaBuilder { log_store: Option, location: Option, storage_options: Option>, - partition_schema: HashSet, + partition_schema: HashMap, partition_strategy: PartitionStrategy, mode: SaveMode, name: Option, @@ -169,7 +165,10 @@ impl ConvertToDeltaBuilder { mut self, partition_schema: impl IntoIterator, ) -> Self { - self.partition_schema = HashSet::from_iter(partition_schema); + self.partition_schema = partition_schema + .into_iter() + .map(|f| (f.name.clone(), f)) + .collect(); self } @@ -276,12 +275,7 @@ impl ConvertToDeltaBuilder { let mut arrow_schemas = Vec::new(); let mut actions = Vec::new(); // partition columns that were defined by caller and are expected to 
apply on this table
-        let mut expected_partitions: HashMap<String, StructField> = self
-            .partition_schema
-            .clone()
-            .into_iter()
-            .map(|field| (field.name.clone(), field))
-            .collect();
+        let mut expected_partitions: HashMap<String, StructField> = self.partition_schema.clone();
         // A HashSet of all unique partition columns in a Parquet table
         let mut partition_columns = HashSet::new();
         // A vector of StructField of all unique partition columns in a Parquet table
@@ -317,12 +311,14 @@ impl ConvertToDeltaBuilder {
             // Safety: we just checked that the key is present in the map
             let field = partition_schema_fields.get(key).unwrap();
             let scalar = if value == NULL_PARTITION_VALUE_DATA_PATH {
-                Ok(crate::kernel::Scalar::Null(field.data_type().clone()))
+                Ok(delta_kernel::expressions::Scalar::Null(
+                    field.data_type().clone(),
+                ))
             } else {
                 let decoded = percent_decode_str(value).decode_utf8()?;
                 match field.data_type() {
                     DataType::Primitive(p) => p.parse_scalar(decoded.as_ref()),
-                    _ => Err(crate::kernel::Error::Generic(format!(
+                    _ => Err(delta_kernel::Error::Generic(format!(
                         "Expected primitive type, found: {:?}",
                         field.data_type()
                     ))),
@@ -390,25 +386,19 @@ impl ConvertToDeltaBuilder {
 
         if !expected_partitions.is_empty() {
             // Partition column provided by the user does not exist in the parquet files
-            return Err(Error::PartitionColumnNotExist(self.partition_schema));
+            return Err(Error::PartitionColumnNotExist);
         }
 
         // Merge parquet file schemas
         // This step is needed because timestamps will not be preserved when copying files in S3. We can't use the schema of the latest parquet file as the Delta table's schema
-        let mut schema_fields = Schema::try_from(&ArrowSchema::try_merge(arrow_schemas)?)?
-            .fields()
-            .clone();
-        schema_fields.append(
-            &mut partition_schema_fields
-                .values()
-                .cloned()
-                .collect::<Vec<_>>(),
-        );
+        let schema = Schema::try_from(&ArrowSchema::try_merge(arrow_schemas)?)?;
+        let mut schema_fields = schema.fields().collect_vec();
+        schema_fields.append(&mut partition_schema_fields.values().collect::<Vec<_>>());
 
         // Generate CreateBuilder with corresponding add actions, schemas and operation meta
         let mut builder = CreateBuilder::new()
            .with_log_store(log_store)
-            .with_columns(schema_fields)
+            .with_columns(schema_fields.into_iter().cloned())
            .with_partition_columns(partition_columns.into_iter())
            .with_actions(actions)
            .with_save_mode(self.mode)
@@ -447,17 +437,20 @@ impl std::future::IntoFuture for ConvertToDeltaBuilder {
 
 #[cfg(test)]
 mod tests {
+    use std::fs;
+
+    use delta_kernel::expressions::Scalar;
+    use itertools::Itertools;
+    use pretty_assertions::assert_eq;
+    use tempfile::tempdir;
+
     use super::*;
     use crate::{
-        kernel::{DataType, PrimitiveType, Scalar},
+        kernel::{DataType, PrimitiveType},
         open_table,
         storage::StorageOptions,
         Path,
     };
-    use itertools::Itertools;
-    use pretty_assertions::assert_eq;
-    use std::fs;
-    use tempfile::tempdir;
 
     fn schema_field(key: &str, primitive: PrimitiveType, nullable: bool) -> StructField {
         StructField::new(key.to_string(), DataType::Primitive(primitive), nullable)
@@ -563,7 +556,8 @@ mod tests {
             .get_schema()
             .expect("Failed to get schema")
             .fields()
-            .clone();
+            .cloned()
+            .collect_vec();
         schema_fields.sort_by(|a, b| a.name().cmp(b.name()));
         assert_eq!(
             schema_fields, expected_schema,
@@ -603,14 +597,15 @@ mod tests {
             "part-00000-d22c627d-9655-4153-9527-f8995620fa42-c000.snappy.parquet"
         );
 
-        let Some(Scalar::Struct(min_values, _)) = action.min_values() else {
+        let Some(Scalar::Struct(data)) = action.min_values() else {
            panic!("Missing min values");
        };
-        assert_eq!(min_values,
vec![Scalar::Date(18628), Scalar::Integer(1)]); - let Some(Scalar::Struct(max_values, _)) = action.max_values() else { + assert_eq!(data.values(), vec![Scalar::Date(18628), Scalar::Integer(1)]); + + let Some(Scalar::Struct(data)) = action.max_values() else { panic!("Missing max values"); }; - assert_eq!(max_values, vec![Scalar::Date(18632), Scalar::Integer(5)]); + assert_eq!(data.values(), vec![Scalar::Date(18632), Scalar::Integer(5)]); assert_delta_table( table, diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index 728358307b..e53ec43c95 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::sync::Arc; +use delta_kernel::schema::MetadataValue; use futures::future::BoxFuture; use maplit::hashset; use serde_json::Value; @@ -128,7 +129,24 @@ impl CreateBuilder { ) -> Self { let mut field = StructField::new(name.into(), data_type, nullable); if let Some(meta) = metadata { - field = field.with_metadata(meta); + field = field.with_metadata(meta.iter().map(|(k, v)| { + ( + k, + if let Value::Number(n) = v { + n.as_i64().map_or_else( + || MetadataValue::String(v.to_string()), + |i| { + i32::try_from(i) + .ok() + .map(MetadataValue::Number) + .unwrap_or_else(|| MetadataValue::String(v.to_string())) + }, + ) + } else { + MetadataValue::String(v.to_string()) + }, + ) + })); }; self.columns.push(field); self @@ -250,8 +268,7 @@ impl CreateBuilder { }; let configuration = self.configuration; - let contains_timestampntz = PROTOCOL.contains_timestampntz(&self.columns); - + let contains_timestampntz = PROTOCOL.contains_timestampntz(self.columns.iter()); // TODO configure more permissive versions based on configuration. Also how should this ideally be handled? // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata? 
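The `with_column` change above folds JSON metadata into kernel `MetadataValue`s: integer JSON numbers that fit in an `i32` become `MetadataValue::Number`, everything else falls back to its string form. A minimal standalone sketch of that mapping (the helper name `json_to_metadata_value` is hypothetical):

use delta_kernel::schema::MetadataValue;
use serde_json::Value;

// Hypothetical helper mirroring the conversion in `CreateBuilder::with_column`:
// only integer JSON numbers that fit in i32 survive as numbers.
fn json_to_metadata_value(value: &Value) -> MetadataValue {
    if let Value::Number(n) = value {
        if let Some(i) = n.as_i64().and_then(|i| i32::try_from(i).ok()) {
            return MetadataValue::Number(i);
        }
    }
    MetadataValue::String(value.to_string())
}

fn main() {
    assert!(matches!(
        json_to_metadata_value(&Value::from(5)),
        MetadataValue::Number(5)
    ));
    // 2^40 does not fit in an i32, so it is stringified.
    assert!(matches!(
        json_to_metadata_value(&Value::from(1_i64 << 40)),
        MetadataValue::String(_)
    ));
}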
@@ -390,7 +407,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_save_mode(SaveMode::Ignore) .await .unwrap(); @@ -410,7 +427,7 @@ mod tests { .await .unwrap() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_save_mode(SaveMode::Ignore) .await .unwrap(); @@ -428,7 +445,7 @@ mod tests { ); let table = CreateBuilder::new() .with_location(format!("./{relative_path}")) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -439,7 +456,7 @@ mod tests { let schema = get_delta_schema(); let table = CreateBuilder::new() .with_location("memory://") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -462,7 +479,7 @@ mod tests { }; let table = CreateBuilder::new() .with_location("memory://") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_actions(vec![Action::Protocol(protocol)]) .await .unwrap(); @@ -471,7 +488,7 @@ mod tests { let table = CreateBuilder::new() .with_location("memory://") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_configuration_property(DeltaConfigKey::AppendOnly, Some("true")) .await .unwrap(); @@ -494,7 +511,7 @@ mod tests { let schema = get_delta_schema(); let table = CreateBuilder::new() .with_location(tmp_dir.path().to_str().unwrap()) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -505,7 +522,7 @@ mod tests { // Check an error is raised when a table exists at location let table = CreateBuilder::new() .with_log_store(log_store.clone()) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_save_mode(SaveMode::ErrorIfExists) .await; assert!(table.is_err()); @@ -513,7 +530,7 @@ mod tests { // Check current table is returned when ignore option is chosen. 
let table = CreateBuilder::new() .with_log_store(log_store.clone()) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_save_mode(SaveMode::Ignore) .await .unwrap(); @@ -522,7 +539,7 @@ mod tests { // Check table is overwritten let table = CreateBuilder::new() .with_log_store(log_store) - .with_columns(schema.fields().iter().cloned()) + .with_columns(schema.fields().cloned()) .with_save_mode(SaveMode::Overwrite) .await .unwrap(); @@ -543,7 +560,7 @@ mod tests { let mut table = DeltaOps(table) .create() - .with_columns(schema.fields().iter().cloned()) + .with_columns(schema.fields().cloned()) .with_save_mode(SaveMode::Overwrite) .await .unwrap(); @@ -567,7 +584,7 @@ mod tests { let mut table = DeltaOps(table) .create() - .with_columns(schema.fields().iter().cloned()) + .with_columns(schema.fields().cloned()) .with_save_mode(SaveMode::Overwrite) .with_partition_columns(vec!["id"]) .await @@ -589,7 +606,7 @@ mod tests { // Fail to create table with unknown Delta key let table = CreateBuilder::new() .with_location("memory://") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_configuration(config.clone()) .await; assert!(table.is_err()); @@ -597,7 +614,7 @@ mod tests { // Succeed in creating table with unknown Delta key since we set raise_if_key_not_exists to false let table = CreateBuilder::new() .with_location("memory://") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_raise_if_key_not_exists(false) .with_configuration(config) .await; diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index aba54cd5f1..4653920965 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -76,11 +76,11 @@ pub struct DeleteMetrics { /// Number of rows copied in the process of deleting files pub num_copied_rows: Option, /// Time taken to execute the entire operation - pub execution_time_ms: u128, + pub execution_time_ms: u64, /// Time taken to scan the file for matches - pub scan_time_ms: u128, + pub scan_time_ms: u64, /// Time taken to rewrite the matched files - pub rewrite_time_ms: u128, + pub rewrite_time_ms: u64, } impl super::Operation<()> for DeleteBuilder {} @@ -207,7 +207,7 @@ async fn execute( let scan_start = Instant::now(); let candidates = find_files(&snapshot, log_store.clone(), &state, predicate.clone()).await?; - metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis(); + metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis() as u64; let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true)))); @@ -225,7 +225,7 @@ async fn execute( writer_properties, ) .await?; - metrics.rewrite_time_ms = Instant::now().duration_since(write_start).as_millis(); + metrics.rewrite_time_ms = Instant::now().duration_since(write_start).as_millis() as u64; add }; let remove = candidates.candidates; @@ -254,7 +254,7 @@ async fn execute( })) } - metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis(); + metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64; commit_properties .app_metadata @@ -355,7 +355,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_partition_columns(partitions.unwrap_or_default()) .await .unwrap(); diff --git a/crates/core/src/operations/merge/mod.rs 
b/crates/core/src/operations/merge/mod.rs index c13da4d879..6c783bc9b4 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -1595,7 +1595,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_partition_columns(partitions.unwrap_or_default()) .await .unwrap(); diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 10cbb6a22a..9e1641fc7f 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -27,6 +27,7 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use arrow::datatypes::SchemaRef as ArrowSchemaRef; use arrow_array::RecordBatch; +use delta_kernel::expressions::Scalar; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{Future, StreamExt, TryStreamExt}; @@ -43,7 +44,7 @@ use tracing::debug; use super::transaction::PROTOCOL; use super::writer::{PartitionWriter, PartitionWriterConfig}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, PartitionsExt, Remove, Scalar}; +use crate::kernel::{scalars::ScalarExt, Action, PartitionsExt, Remove}; use crate::logstore::LogStoreRef; use crate::operations::transaction::{CommitBuilder, CommitProperties, DEFAULT_RETRIES}; use crate::protocol::DeltaOperation; @@ -1001,7 +1002,6 @@ fn build_zorder_plan( let field_names = snapshot .schema() .fields() - .iter() .map(|field| field.name().to_string()) .collect_vec(); let unknown_columns = zorder_columns diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index ac5bab7738..d1ab9269bc 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -81,20 +81,19 @@ impl ProtocolChecker { } /// checks if table contains timestamp_ntz in any field including nested fields. 
- pub fn contains_timestampntz(&self, fields: &[StructField]) -> bool { - fn check_vec_fields(fields: &[StructField]) -> bool { - fields.iter().any(|f| _check_type(f.data_type())) - } - + pub fn contains_timestampntz<'a>( + &self, + mut fields: impl Iterator, + ) -> bool { fn _check_type(dtype: &DataType) -> bool { match dtype { - &DataType::TIMESTAMPNTZ => true, + &DataType::TIMESTAMP_NTZ => true, DataType::Array(inner) => _check_type(inner.element_type()), - DataType::Struct(inner) => check_vec_fields(inner.fields()), + DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())), _ => false, } } - check_vec_fields(fields) + fields.any(|f| _check_type(f.data_type())) } /// Check can write_timestamp_ntz @@ -164,7 +163,7 @@ impl ProtocolChecker { if (4..7).contains(&min_writer_version) { debug!("min_writer_version is less 4-6, checking for unsupported table features"); if let Ok(schema) = snapshot.metadata().schema() { - for field in schema.fields.iter() { + for field in schema.fields() { if field.metadata.contains_key( crate::kernel::ColumnMetadataKey::GenerationExpression.as_ref(), ) { diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 9ec8519b9b..31946d104e 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -569,7 +569,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_partition_columns(partitions.unwrap_or_default()) .await .unwrap(); @@ -859,7 +859,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); let table = write_batch(table, batch).await; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index c435a3df08..1cdf2780bd 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -320,7 +320,7 @@ impl WriteBuilder { }?; let mut builder = CreateBuilder::new() .with_log_store(self.log_store.clone()) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .with_configuration(self.configuration.clone()); if let Some(partition_columns) = self.partition_columns.as_ref() { builder = builder.with_partition_columns(partition_columns.clone()) @@ -979,7 +979,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -1242,7 +1242,7 @@ mod tests { assert_eq!(table.version(), 1); let new_schema = table.metadata().unwrap().schema().unwrap(); let fields = new_schema.fields(); - let names = fields.iter().map(|f| f.name()).collect::>(); + let names = fields.map(|f| f.name()).collect::>(); assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); } @@ -1300,7 +1300,7 @@ mod tests { assert_eq!(table.version(), 1); let new_schema = table.metadata().unwrap().schema().unwrap(); let fields = new_schema.fields(); - let mut names = fields.iter().map(|f| f.name()).collect::>(); + let mut names = fields.map(|f| f.name()).collect::>(); names.sort(); assert_eq!(names, vec!["id", "inserted_by", "modified", "value"]); let part_cols = table.metadata().unwrap().partition_columns.clone(); @@ -1417,7 +1417,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() .with_save_mode(SaveMode::ErrorIfExists) - 
.with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -1439,7 +1439,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() .with_save_mode(SaveMode::ErrorIfExists) - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); @@ -1455,7 +1455,7 @@ mod tests { let table = DeltaOps::new_in_memory() .create() - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .await .unwrap(); assert_eq!(table.version(), 0); diff --git a/crates/core/src/operations/writer.rs b/crates/core/src/operations/writer.rs index f04d68e412..e5e6901608 100644 --- a/crates/core/src/operations/writer.rs +++ b/crates/core/src/operations/writer.rs @@ -6,6 +6,7 @@ use arrow::datatypes::SchemaRef as ArrowSchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use bytes::Bytes; +use delta_kernel::expressions::Scalar; use indexmap::IndexMap; use object_store::{path::Path, ObjectStore}; use parquet::arrow::ArrowWriter; @@ -15,7 +16,7 @@ use tracing::debug; use crate::crate_version; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Add, PartitionsExt, Scalar}; +use crate::kernel::{Add, PartitionsExt}; use crate::storage::ObjectStoreRef; use crate::writer::record_batch::{divide_by_partition_values, PartitionResult}; use crate::writer::stats::create_add; diff --git a/crates/core/src/protocol/checkpoints.rs b/crates/core/src/protocol/checkpoints.rs index 67994c5e49..6bf19a81f5 100644 --- a/crates/core/src/protocol/checkpoints.rs +++ b/crates/core/src/protocol/checkpoints.rs @@ -8,6 +8,7 @@ use arrow_schema::ArrowError; use chrono::{Datelike, NaiveDate, NaiveDateTime, Utc}; use futures::{StreamExt, TryStreamExt}; +use itertools::Itertools; use lazy_static::lazy_static; use object_store::{Error, ObjectStore}; use parquet::arrow::ArrowWriter; @@ -259,7 +260,8 @@ fn parquet_bytes_from_state( // Collect a map of paths that require special stats conversion. 
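+    // Note: `StructType::fields()` is iterator-based after the kernel
+    // migration; it is materialized below so that a slice of `&StructField`
+    // references can be passed to `collect_stats_conversions`.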
diff --git a/crates/core/src/protocol/checkpoints.rs b/crates/core/src/protocol/checkpoints.rs
index 67994c5e49..6bf19a81f5 100644
--- a/crates/core/src/protocol/checkpoints.rs
+++ b/crates/core/src/protocol/checkpoints.rs
@@ -8,6 +8,7 @@ use arrow_schema::ArrowError;
 use chrono::{Datelike, NaiveDate, NaiveDateTime, Utc};
 use futures::{StreamExt, TryStreamExt};
+use itertools::Itertools;
 use lazy_static::lazy_static;
 use object_store::{Error, ObjectStore};
 use parquet::arrow::ArrowWriter;
@@ -259,7 +260,8 @@ fn parquet_bytes_from_state(
 // Collect a map of paths that require special stats conversion.
 let mut stats_conversions: Vec<(SchemaPath, DataType)> = Vec::new();
- collect_stats_conversions(&mut stats_conversions, schema.fields().as_slice());
+ let fields = schema.fields().collect_vec();
+ collect_stats_conversions(&mut stats_conversions, fields.as_slice());

 // if any, tombstones do not include extended file metadata, we must omit the extended metadata fields from the remove schema
 // See https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file
@@ -477,7 +479,7 @@ fn typed_partition_value_from_option_string(
 }
 }

-fn collect_stats_conversions(paths: &mut Vec<(SchemaPath, DataType)>, fields: &[StructField]) {
+fn collect_stats_conversions(paths: &mut Vec<(SchemaPath, DataType)>, fields: &[&StructField]) {
 let mut _path = SchemaPath::new();
 fields
 .iter()
@@ -498,9 +500,7 @@ fn collect_field_conversion(
 DataType::Struct(struct_field) => {
 let struct_fields = struct_field.fields();
 current_path.push(field.name().to_owned());
- struct_fields
- .iter()
- .for_each(|f| collect_field_conversion(current_path, all_paths, f));
+ struct_fields.for_each(|f| collect_field_conversion(current_path, all_paths, f));
 current_path.pop();
 }
 _ => { /* noop */ }
@@ -560,7 +560,7 @@ mod tests {

 let table = DeltaOps::new_in_memory()
 .create()
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_save_mode(crate::protocol::SaveMode::Ignore)
 .await
 .unwrap();
@@ -592,7 +592,7 @@ mod tests {

 let mut table = DeltaOps::new_in_memory()
 .create()
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_save_mode(crate::protocol::SaveMode::Ignore)
 .await
 .unwrap();
@@ -668,7 +668,7 @@ mod tests {

 let table = DeltaOps::new_in_memory()
 .create()
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_save_mode(crate::protocol::SaveMode::Ignore)
 .await
 .unwrap();
@@ -802,9 +802,8 @@ mod tests {
 #[test]
 fn collect_stats_conversions_test() {
 let delta_schema: StructType = serde_json::from_value(SCHEMA.clone()).unwrap();
- let fields = delta_schema.fields();
+ let fields = delta_schema.fields().collect_vec();
 let mut paths = Vec::new();
-
 collect_stats_conversions(&mut paths, fields.as_slice());

 assert_eq!(2, paths.len());
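Where a helper still wants a slice, the iterator has to be materialized first; that is what the `collect_vec()` calls above do, and why `collect_stats_conversions` now takes `&[&StructField]` instead of `&[StructField]`. A sketch of the pattern with plain `collect` and toy types (itertools' `collect_vec()` is just shorthand for `collect::<Vec<_>>()`):

```rust
// Toy sketch: materializing an iterator of references so a slice-taking
// helper can be called on it.

struct StructField {
    name: String,
}

// The reworked helper borrows the fields twice over (a slice of references),
// matching the `&[&StructField]` signature in the hunk above.
fn collect_names(paths: &mut Vec<String>, fields: &[&StructField]) {
    for f in fields {
        paths.push(f.name.clone());
    }
}

fn main() {
    let owned = vec![
        StructField { name: "id".into() },
        StructField { name: "ts".into() },
    ];
    // An iterator of `&StructField` has no `.as_slice()`; collect it first.
    let fields: Vec<&StructField> = owned.iter().collect();
    let mut paths = Vec::new();
    collect_names(&mut paths, fields.as_slice());
    assert_eq!(paths, vec!["id", "ts"]);
}
```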
diff --git a/crates/core/src/schema/partitions.rs b/crates/core/src/schema/partitions.rs
index c766c1d630..d2b2e84979 100644
--- a/crates/core/src/schema/partitions.rs
+++ b/crates/core/src/schema/partitions.rs
@@ -1,12 +1,13 @@
 //! Delta Table partition handling logic.
-//!
+
+use delta_kernel::expressions::Scalar;
 use serde::{Serialize, Serializer};
 use std::cmp::Ordering;
 use std::collections::HashMap;
 use std::convert::TryFrom;

 use crate::errors::DeltaTableError;
-use crate::kernel::{DataType, PrimitiveType, Scalar};
+use crate::kernel::{scalars::ScalarExt, DataType, PrimitiveType};

 /// A special value used in Hive to represent the null partition in partitioned tables
 pub const NULL_PARTITION_VALUE_DATA_PATH: &str = "__HIVE_DEFAULT_PARTITION__";
@@ -32,6 +33,42 @@ pub enum PartitionValue {
 NotIn(Vec<String>),
 }

+#[derive(Clone, Debug, PartialEq)]
+struct ScalarHelper<'a>(&'a Scalar);
+
+impl PartialOrd for ScalarHelper<'_> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ use Scalar::*;
+ match (self.0, other.0) {
+ (Null(_), Null(_)) => Some(Ordering::Equal),
+ (Integer(a), Integer(b)) => a.partial_cmp(b),
+ (Long(a), Long(b)) => a.partial_cmp(b),
+ (Short(a), Short(b)) => a.partial_cmp(b),
+ (Byte(a), Byte(b)) => a.partial_cmp(b),
+ (Float(a), Float(b)) => a.partial_cmp(b),
+ (Double(a), Double(b)) => a.partial_cmp(b),
+ (String(a), String(b)) => a.partial_cmp(b),
+ (Boolean(a), Boolean(b)) => a.partial_cmp(b),
+ (Timestamp(a), Timestamp(b)) => a.partial_cmp(b),
+ (TimestampNtz(a), TimestampNtz(b)) => a.partial_cmp(b),
+ (Date(a), Date(b)) => a.partial_cmp(b),
+ (Binary(a), Binary(b)) => a.partial_cmp(b),
+ (Decimal(a, p1, s1), Decimal(b, p2, s2)) => {
+ // TODO implement proper decimal comparison
+ if p1 != p2 || s1 != s2 {
+ return None;
+ };
+ a.partial_cmp(b)
+ }
+ // TODO should we make an assumption about the ordering of nulls?
+ // right now this is only used for internal purposes.
+ (Null(_), _) => Some(Ordering::Less),
+ (_, Null(_)) => Some(Ordering::Greater),
+ _ => None,
+ }
+ }
+}
+
 /// A Struct used for filtering a DeltaTable partition by key and value.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub struct PartitionFilter {
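The private `ScalarHelper` newtype added above exists because the upstream `Scalar` does not implement `PartialOrd`, and Rust's orphan rule forbids this crate from implementing a foreign trait for a foreign type directly. Wrapping the reference in a local type sidesteps that. The same pattern on a two-variant toy enum:

```rust
use std::cmp::Ordering;

// Stand-in for a type from another crate that lacks PartialOrd.
#[derive(Debug, PartialEq)]
enum Scalar {
    Integer(i32),
    Null,
}

// Local newtype: `impl PartialOrd for Scalar` would be an orphan impl,
// but a wrapper type defined here is ours to implement on.
struct ScalarHelper<'a>(&'a Scalar);

impl PartialEq for ScalarHelper<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl PartialOrd for ScalarHelper<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        use Scalar::*;
        match (self.0, other.0) {
            (Integer(a), Integer(b)) => a.partial_cmp(b),
            (Null, Null) => Some(Ordering::Equal),
            // Same convention as above: nulls sort first, for internal use only.
            (Null, _) => Some(Ordering::Less),
            (_, Null) => Some(Ordering::Greater),
        }
    }
}

fn main() {
    let (a, b) = (Scalar::Integer(1), Scalar::Integer(2));
    assert_eq!(
        ScalarHelper(&a).partial_cmp(&ScalarHelper(&b)),
        Some(Ordering::Less)
    );
}
```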
@@ -49,7 +86,7 @@ fn compare_typed_value(
 match data_type {
 DataType::Primitive(primitive_type) => {
 let other = primitive_type.parse_scalar(filter_value).ok()?;
- partition_value.partial_cmp(&other)
+ ScalarHelper(partition_value).partial_cmp(&ScalarHelper(&other))
 }
 // NOTE: complex types are not supported as partition columns
 _ => None,
diff --git a/crates/core/src/table/config.rs b/crates/core/src/table/config.rs
index 05fb0c53ca..3512e3abb5 100644
--- a/crates/core/src/table/config.rs
+++ b/crates/core/src/table/config.rs
@@ -2,12 +2,12 @@
 use std::time::Duration;
 use std::{collections::HashMap, str::FromStr};

+use delta_kernel::column_mapping::ColumnMappingMode;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};

-use crate::errors::DeltaTableError;
-
 use super::Constraint;
+use crate::errors::DeltaTableError;

 /// Typed property keys that can be defined on a delta table
 ///
@@ -463,49 +463,6 @@ impl FromStr for CheckpointPolicy {
 }
 }

-#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialEq)]
-/// The Column Mapping modes used for reading and writing data
-#[serde(rename_all = "camelCase")]
-pub enum ColumnMappingMode {
- /// No column mapping is applied
- None,
- /// Columns are mapped by their field_id in parquet
- Id,
- /// Columns are mapped to a physical name
- Name,
-}
-
-impl Default for ColumnMappingMode {
- fn default() -> Self {
- Self::None
- }
-}
-
-impl AsRef<str> for ColumnMappingMode {
- fn as_ref(&self) -> &str {
- match self {
- Self::None => "none",
- Self::Id => "id",
- Self::Name => "name",
- }
- }
-}
-
-impl FromStr for ColumnMappingMode {
- type Err = DeltaTableError;
-
- fn from_str(s: &str) -> Result<Self, Self::Err> {
- match s.to_ascii_lowercase().as_str() {
- "none" => Ok(Self::None),
- "id" => Ok(Self::Id),
- "name" => Ok(Self::Name),
- _ => Err(DeltaTableError::Generic(
- "Invalid string for ColumnMappingMode".into(),
- )),
- }
- }
-}
-
 const SECONDS_PER_MINUTE: u64 = 60;
 const SECONDS_PER_HOUR: u64 = 60 * SECONDS_PER_MINUTE;
 const SECONDS_PER_DAY: u64 = 24 * SECONDS_PER_HOUR;
diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs
index 4b818513b0..969f470bfb 100644
--- a/crates/core/src/table/mod.rs
+++ b/crates/core/src/table/mod.rs
@@ -163,7 +163,6 @@ pub(crate) fn get_partition_col_data_types<'a>(
 // When loading `partitionValues_parsed` we have to convert the stringified partition values back to the correct data type.
 schema
 .fields()
- .iter()
 .filter_map(|f| {
 if metadata
 .partition_columns
diff --git a/crates/core/src/table/state_arrow.rs b/crates/core/src/table/state_arrow.rs
index fe35787cb4..24d4a474ff 100644
--- a/crates/core/src/table/state_arrow.rs
+++ b/crates/core/src/table/state_arrow.rs
@@ -14,9 +14,9 @@ use arrow_array::{
 StringArray, StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
 };
 use arrow_schema::{DataType, Field, Fields, TimeUnit};
+use delta_kernel::column_mapping::ColumnMappingMode;
 use itertools::Itertools;

-use super::config::ColumnMappingMode;
 use super::state::DeltaTableState;
 use crate::errors::DeltaTableError;
 use crate::kernel::{Add, DataType as DeltaDataType, StructType};
@@ -149,7 +149,13 @@ impl DeltaTableState {
 .map(
 |name| -> Result<DataType, DeltaTableError> {
 let schema = metadata.schema()?;
- let field = schema.field_with_name(name)?;
+ let field =
+ schema
+ .field(name)
+ .ok_or(DeltaTableError::MetadataError(format!(
+ "Invalid partition column {0}",
+ name
+ )))?;
 Ok(field.data_type().try_into()?)
}, ) @@ -173,12 +179,12 @@ impl DeltaTableState { .map(|name| -> Result<_, DeltaTableError> { let physical_name = self .schema() - .field_with_name(name) - .or(Err(DeltaTableError::MetadataError(format!( + .field(name) + .ok_or(DeltaTableError::MetadataError(format!( "Invalid partition column {0}", name - ))))? - .physical_name()? + )))? + .physical_name(column_mapping_mode)? .to_string(); Ok((physical_name, name.as_str())) }) @@ -674,7 +680,6 @@ impl<'a> SchemaLeafIterator<'a> { SchemaLeafIterator { fields_remaining: schema .fields() - .iter() .map(|field| (vec![field.name().as_ref()], field.data_type())) .collect(), } diff --git a/crates/core/src/writer/json.rs b/crates/core/src/writer/json.rs index d97d3ef16c..ab1ccac5f2 100644 --- a/crates/core/src/writer/json.rs +++ b/crates/core/src/writer/json.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use arrow::datatypes::{Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; use arrow::record_batch::*; use bytes::Bytes; +use delta_kernel::expressions::Scalar; use indexmap::IndexMap; use object_store::path::Path; use object_store::ObjectStore; @@ -24,7 +25,7 @@ use super::utils::{ }; use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; -use crate::kernel::{Add, PartitionsExt, Scalar, StructType}; +use crate::kernel::{scalars::ScalarExt, Add, PartitionsExt, StructType}; use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::table::config::DEFAULT_NUM_INDEX_COLS; @@ -616,7 +617,7 @@ mod tests { .with_location(&path) .with_table_name("test-table") .with_comment("A table for running tests") - .with_columns(schema.fields().clone()) + .with_columns(schema.fields().cloned()) .await .unwrap(); table.load().await.expect("Failed to load table"); diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index c21435dd14..9cdc6a4322 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -14,6 +14,7 @@ use arrow_array::ArrayRef; use arrow_row::{RowConverter, SortField}; use arrow_schema::{ArrowError, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef}; use bytes::Bytes; +use delta_kernel::expressions::Scalar; use indexmap::IndexMap; use object_store::{path::Path, ObjectStore}; use parquet::{arrow::ArrowWriter, errors::ParquetError}; @@ -28,7 +29,7 @@ use super::utils::{ }; use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; -use crate::kernel::{Action, Add, PartitionsExt, Scalar, StructType}; +use crate::kernel::{scalars::ScalarExt, Action, Add, PartitionsExt, StructType}; use crate::operations::cast::merge_schema; use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; @@ -539,7 +540,7 @@ mod tests { let table = DeltaOps(table) .create() .with_partition_columns(partition_cols.to_vec()) - .with_columns(delta_schema.fields().clone()) + .with_columns(delta_schema.fields().cloned()) .await .unwrap(); @@ -659,7 +660,7 @@ mod tests { .with_location(table_path.to_str().unwrap()) .with_table_name("test-table") .with_comment("A table for running tests") - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_partition_columns(partition_cols) .await .unwrap(); @@ -735,7 +736,7 @@ mod tests { .with_location(table_path.to_str().unwrap()) .with_table_name("test-table") .with_comment("A table for running tests") - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) 
.await .unwrap(); table.load().await.expect("Failed to load table"); @@ -779,8 +780,7 @@ mod tests { let new_schema = table.metadata().unwrap().schema().unwrap(); let expected_columns = vec!["id", "value", "modified", "vid", "name"]; - let found_columns: Vec<&String> = - new_schema.fields().iter().map(|f| f.name()).collect(); + let found_columns: Vec<&String> = new_schema.fields().map(|f| f.name()).collect(); assert_eq!( expected_columns, found_columns, "The new table schema does not contain all evolved columns as expected" @@ -797,7 +797,7 @@ mod tests { .with_location(table_path.to_str().unwrap()) .with_table_name("test-table") .with_comment("A table for running tests") - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .with_partition_columns(["id"]) .await .unwrap(); @@ -928,7 +928,7 @@ mod tests { .with_location(table_path.to_str().unwrap()) .with_table_name("test-table") .with_comment("A table for running tests") - .with_columns(table_schema.fields().clone()) + .with_columns(table_schema.fields().cloned()) .await .unwrap(); table.load().await.expect("Failed to load table"); diff --git a/crates/core/src/writer/stats.rs b/crates/core/src/writer/stats.rs index 0cea01ee6a..28a089ae1c 100644 --- a/crates/core/src/writer/stats.rs +++ b/crates/core/src/writer/stats.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use std::{collections::HashMap, ops::AddAssign}; +use delta_kernel::expressions::Scalar; use indexmap::IndexMap; use parquet::file::metadata::ParquetMetaData; use parquet::format::FileMetaData; @@ -14,7 +15,7 @@ use parquet::{ }; use super::*; -use crate::kernel::{Add, Scalar}; +use crate::kernel::{scalars::ScalarExt, Add}; use crate::protocol::{ColumnValueStat, Stats}; /// Creates an [`Add`] log action struct. 
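`stats.rs`, like the writers above, now pairs the upstream `Scalar` with the crate-local `scalars::ScalarExt` trait: helper methods such as `serialize()` that used to live on the in-tree `Scalar` are reattached to the foreign type through an extension trait. A toy sketch of that pattern (names illustrative, not the real delta-rs API):

```rust
// Extension-trait pattern: add crate-local methods to a (conceptually
// foreign) type without owning its definition.

enum Scalar {
    Integer(i32),
    String(String),
    Null,
}

// Local extension trait, implemented for the foreign type.
trait ScalarExt {
    fn serialize(&self) -> String;
}

impl ScalarExt for Scalar {
    fn serialize(&self) -> String {
        match self {
            Scalar::Integer(i) => i.to_string(),
            Scalar::String(s) => s.clone(),
            Scalar::Null => "null".to_string(),
        }
    }
}

fn main() {
    // Call sites only need the trait in scope, mirroring the
    // `use crate::kernel::{scalars::ScalarExt, ...};` lines in these hunks.
    let v = Scalar::Integer(42);
    assert_eq!(v.serialize(), "42");
}
```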
diff --git a/crates/core/src/writer/test_utils.rs b/crates/core/src/writer/test_utils.rs
index 093ad7cbd0..ff860ed1cf 100644
--- a/crates/core/src/writer/test_utils.rs
+++ b/crates/core/src/writer/test_utils.rs
@@ -276,7 +276,7 @@ pub async fn setup_table_with_configuration(
 let table_schema = get_delta_schema();
 DeltaOps::new_in_memory()
 .create()
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_configuration_property(key, value)
 .await
 .expect("Failed to create table")
@@ -299,7 +299,7 @@ pub async fn create_initialized_table(partition_cols: &[String]) -> DeltaTable {
 .with_location(table_path.to_str().unwrap())
 .with_table_name("test-table")
 .with_comment("A table for running tests")
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_partition_columns(partition_cols)
 .await
 .unwrap()
diff --git a/crates/core/tests/command_merge.rs b/crates/core/tests/command_merge.rs
index 59a941a24f..10855aa0a8 100644
--- a/crates/core/tests/command_merge.rs
+++ b/crates/core/tests/command_merge.rs
@@ -19,7 +19,7 @@ async fn create_table(table_uri: &str, partition: Option<Vec<String>>) -> DeltaTable {
 let ops = DeltaOps::try_from_uri(table_uri).await.unwrap();
 let table = ops
 .create()
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_partition_columns(partition.unwrap_or_default())
 .await
 .expect("Failed to create table");
diff --git a/crates/core/tests/command_optimize.rs b/crates/core/tests/command_optimize.rs
index 4f26c55fd4..13cbd168e4 100644
--- a/crates/core/tests/command_optimize.rs
+++ b/crates/core/tests/command_optimize.rs
@@ -249,7 +249,7 @@ async fn test_optimize_with_partitions() -> Result<(), Box<dyn Error>> {
 let partition_values = partition_adds[0].partition_values()?;
 assert_eq!(
 partition_values.get("date"),
- Some(&deltalake_core::kernel::Scalar::String(
+ Some(&delta_kernel::expressions::Scalar::String(
 "2022-05-22".to_string()
 ))
 );
diff --git a/crates/core/tests/fs_common/mod.rs b/crates/core/tests/fs_common/mod.rs
index 3ef7c82edf..e3d9e722e4 100644
--- a/crates/core/tests/fs_common/mod.rs
+++ b/crates/core/tests/fs_common/mod.rs
@@ -55,7 +55,7 @@ pub async fn create_test_table(
 .with_location(path)
 .with_table_name("test-table")
 .with_comment("A table for running tests")
- .with_columns(schema.fields().clone())
+ .with_columns(schema.fields().cloned())
 .with_partition_columns(partition_columns)
 .with_configuration(config)
 .await
diff --git a/crates/core/tests/integration_datafusion.rs b/crates/core/tests/integration_datafusion.rs
index 64d80e3bce..cb3cc41edb 100644
--- a/crates/core/tests/integration_datafusion.rs
+++ b/crates/core/tests/integration_datafusion.rs
@@ -1,14 +1,10 @@
 #![cfg(feature = "datafusion")]
-
-use arrow::array::Int64Array;
-use deltalake_test::datafusion::*;
-use deltalake_test::utils::*;
-use serial_test::serial;
-
 use std::collections::{HashMap, HashSet};
+use std::error::Error;
 use std::path::PathBuf;
 use std::sync::Arc;

+use arrow::array::Int64Array;
 use arrow::array::*;
 use arrow::record_batch::RecordBatch;
 use arrow_schema::{
@@ -28,8 +24,6 @@ use datafusion_expr::Expr;
 use datafusion_proto::bytes::{
 physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
 };
-use url::Url;
-
 use deltalake_core::delta_datafusion::{DeltaPhysicalCodec, DeltaScan};
 use deltalake_core::kernel::{DataType, MapType, PrimitiveType, StructField, StructType};
 use deltalake_core::logstore::logstore_for;
@@ -41,7 +35,10 @@ use deltalake_core::{
 operations::{write::WriteBuilder, DeltaOps},
 DeltaTable, DeltaTableError,
 };
-use std::error::Error;
+use deltalake_test::datafusion::*;
+use deltalake_test::utils::*;
+use serial_test::serial;
+use url::Url;

 mod local {
 use datafusion::common::stats::Precision;
@@ -106,7 +103,7 @@ mod local {
 .unwrap()
 .create()
 .with_save_mode(SaveMode::Ignore)
- .with_columns(table_schema.fields().clone())
+ .with_columns(table_schema.fields().cloned())
 .with_partition_columns(partitions)
 .await
 .unwrap();
@@ -198,10 +195,8 @@ mod local {
 &ctx,
 &DeltaPhysicalCodec {},
 )?;
- let fields = StructType::try_from(source_scan.schema())
- .unwrap()
- .fields()
- .clone();
+ let schema = StructType::try_from(source_scan.schema()).unwrap();
+ let fields = schema.fields().cloned();

 // Create target Delta Table
 let target_table = CreateBuilder::new()
@@ -1035,7 +1030,7 @@ mod local {
 deltalake_core::DeltaTableBuilder::from_uri("./tests/data/issue-1619").build()?;
 let _ = DeltaOps::from(table)
 .create()
- .with_columns(schema.fields().to_owned())
+ .with_columns(schema.fields().cloned())
 .await?;
 let mut table = open_table("./tests/data/issue-1619").await?;
diff --git a/crates/test/src/concurrent.rs b/crates/test/src/concurrent.rs
index dc4f3168e3..d028917a1e 100644
--- a/crates/test/src/concurrent.rs
+++ b/crates/test/src/concurrent.rs
@@ -34,7 +34,7 @@ async fn prepare_table(

 let table = DeltaOps(table)
 .create()
- .with_columns(schema.fields().clone())
+ .with_columns(schema.fields().cloned())
 .await?;

 assert_eq!(0, table.version());
diff --git a/crates/test/src/lib.rs b/crates/test/src/lib.rs
index aedb24844c..0a1ca39539 100644
--- a/crates/test/src/lib.rs
+++ b/crates/test/src/lib.rs
@@ -86,7 +86,7 @@ impl TestContext {
 .with_log_store(log_store)
 .with_table_name("delta-rs_test_table")
 .with_comment("Table created by delta-rs tests")
- .with_columns(schema.fields().clone())
+ .with_columns(schema.fields().cloned())
 .with_partition_columns(p)
 .await
 .unwrap()
diff --git a/python/Cargo.toml b/python/Cargo.toml
index f5b4cf5b5b..3938bd0aa9 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -15,6 +15,8 @@ crate-type = ["cdylib"]
 doc = false

 [dependencies]
+delta_kernel.workspace = true
+
 # arrow
 arrow-schema = { workspace = true, features = ["serde"] }

diff --git a/python/src/lib.rs b/python/src/lib.rs
index b7fc21f3cf..14cbf6f916 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -14,6 +14,7 @@ use std::time::{SystemTime, UNIX_EPOCH};

 use arrow::pyarrow::PyArrowType;
 use chrono::{DateTime, Duration, FixedOffset, Utc};
+use delta_kernel::expressions::Scalar;
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
 use deltalake::arrow::record_batch::RecordBatchReader;
@@ -26,7 +27,9 @@ use deltalake::datafusion::physical_plan::ExecutionPlan;
 use deltalake::datafusion::prelude::SessionContext;
 use deltalake::delta_datafusion::DeltaDataChecker;
 use deltalake::errors::DeltaTableError;
-use deltalake::kernel::{Action, Add, Invariant, LogicalFile, Remove, Scalar, StructType};
+use deltalake::kernel::{
+ scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType,
+};
 use deltalake::operations::collect_sendable_stream;
 use deltalake::operations::constraints::ConstraintBuilder;
 use deltalake::operations::convert_to_delta::{ConvertToDeltaBuilder, PartitionStrategy};
@@ -949,7 +952,6 @@ impl RawDeltaTable {
 .get_schema()
 .map_err(|_| DeltaProtocolError::new_err("table does not yet have a schema"))?
 .fields()
- .iter()
 .map(|field| field.name().as_str())
 .collect();
 let partition_columns: HashSet<&str> = self
@@ -1362,9 +1364,9 @@ fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult<PyObject> {
 value.serialize().to_object(py),
- Struct(values, fields) => {
+ Struct(data) => {
 let py_struct = PyDict::new(py);
- for (field, value) in fields.iter().zip(values.iter()) {
+ for (field, value) in data.fields().iter().zip(data.values().iter()) {
 py_struct.set_item(field.name(), scalar_to_py(value, py_date, py)?)?;
 }
 py_struct.to_object(py)
@@ -1431,8 +1433,8 @@ fn filestats_to_expression_next<'py>(
 let mut has_nulls_set: HashSet<String> = HashSet::new();

 // NOTE: null_counts should always return a struct scalar.
- if let Some(Scalar::Struct(values, fields)) = file_info.null_counts() {
- for (field, value) in fields.iter().zip(values.iter()) {
+ if let Some(Scalar::Struct(data)) = file_info.null_counts() {
+ for (field, value) in data.fields().iter().zip(data.values().iter()) {
 if let Scalar::Long(val) = value {
 if *val == 0 {
 expressions.push(py_field.call1((field.name(),))?.call_method0("is_valid"));
@@ -1446,11 +1448,11 @@
 }

 // NOTE: min_values should always return a struct scalar.
- if let Some(Scalar::Struct(values, fields)) = file_info.min_values() {
- for (field, value) in fields.iter().zip(values.iter()) {
+ if let Some(Scalar::Struct(data)) = file_info.min_values() {
+ for (field, value) in data.fields().iter().zip(data.values().iter()) {
 match value {
 // TODO: Handle nested field statistics.
- Scalar::Struct(_, _) => {}
+ Scalar::Struct(_) => {}
 _ => {
 let maybe_minimum =
 cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0);
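The `Scalar::Struct(values, fields)` → `Scalar::Struct(data)` rewrites in these hunks reflect the kernel collapsing two parallel vectors into a single payload struct with accessor methods. A toy model of that shape (stand-ins only, not the real `delta_kernel` `StructData`):

```rust
// Toy model: one payload struct with accessors replaces parallel Vecs.

#[derive(Debug)]
struct Field {
    name: String,
}

#[derive(Debug)]
struct StructData {
    fields: Vec<Field>,
    values: Vec<Scalar>,
}

impl StructData {
    fn fields(&self) -> &[Field] {
        &self.fields
    }
    fn values(&self) -> &[Scalar] {
        &self.values
    }
}

#[derive(Debug)]
enum Scalar {
    Long(i64),
    Struct(StructData),
}

fn main() {
    let s = Scalar::Struct(StructData {
        fields: vec![Field { name: "count".into() }],
        values: vec![Scalar::Long(0)],
    });
    // Matching binds one value now; the pairs come from zipping accessors,
    // exactly as in the hunks above.
    if let Scalar::Struct(data) = &s {
        for (field, value) in data.fields().iter().zip(data.values().iter()) {
            println!("{} = {:?}", field.name, value);
        }
    }
}
```

Bundling the two vectors also makes the invariant (equal lengths, matching order) a property of one type rather than of every call site.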
@@ -1473,11 +1475,11 @@
 }

 // NOTE: max_values should always return a struct scalar.
- if let Some(Scalar::Struct(values, fields)) = file_info.max_values() {
- for (field, value) in fields.iter().zip(values.iter()) {
+ if let Some(Scalar::Struct(data)) = file_info.max_values() {
+ for (field, value) in data.fields().iter().zip(data.values().iter()) {
 match value {
 // TODO: Handle nested field statistics.
- Scalar::Struct(_, _) => {}
+ Scalar::Struct(_) => {}
 _ => {
 let maybe_maximum =
 cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0);
@@ -1668,7 +1670,7 @@ fn create_deltalake(

 let mut builder = DeltaOps(table)
 .create()
- .with_columns(schema.fields().clone())
+ .with_columns(schema.fields().cloned())
 .with_save_mode(mode)
 .with_raise_if_key_not_exists(raise_if_key_not_exists)
 .with_partition_columns(partition_by);
@@ -1723,7 +1725,7 @@ fn write_new_deltalake(

 let mut builder = DeltaOps(table)
 .create()
- .with_columns(schema.fields().clone())
+ .with_columns(schema.fields().cloned())
 .with_partition_columns(partition_by)
 .with_actions(add_actions.iter().map(|add| Action::Add(add.into())));
@@ -1770,7 +1772,7 @@ fn convert_to_deltalake(

 if let Some(part_schema) = partition_schema {
 let schema: StructType = (&part_schema.0).try_into().map_err(PythonError::from)?;
- builder = builder.with_partition_schema(schema.fields().clone());
+ builder = builder.with_partition_schema(schema.fields().cloned());
 }

 if let Some(partition_strategy) = &partition_strategy {
diff --git a/python/src/schema.rs b/python/src/schema.rs
index c4a250a57a..36f301ab98 100644
--- a/python/src/schema.rs
+++ b/python/src/schema.rs
@@ -7,8 +7,8 @@ use deltalake::arrow::datatypes::{
 use deltalake::arrow::error::ArrowError;
 use deltalake::arrow::pyarrow::PyArrowType;
 use deltalake::kernel::{
- ArrayType as DeltaArrayType, DataType, MapType as DeltaMapType, PrimitiveType as DeltaPrimitve,
- StructField, StructType as DeltaStructType,
+ ArrayType as DeltaArrayType, DataType, MapType as DeltaMapType, MetadataValue,
+ PrimitiveType as DeltaPrimitve, StructField, StructType as DeltaStructType, StructTypeExt,
 };
 use pyo3::exceptions::{PyException, PyNotImplementedError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
@@ -98,30 +98,6 @@ impl PrimitiveType {
 Ok(Self {
 inner_type: data_type,
 })
-
- // if data_type.starts_with("decimal") {
- // if try_parse_decimal_type(&data_type).is_none() {
- // Err(PyValueError::new_err(format!(
- // "invalid decimal type: {data_type}"
- // )))
- // } else {
- // Ok(Self {
- // inner_type: data_type,
- // })
- // }
- // } else if !VALID_PRIMITIVE_TYPES
- // .iter()
- // .any(|&valid| data_type == valid)
- // {
- // Err(PyValueError::new_err(format!(
- // "data_type must be one of decimal(<precision>, <scale>), {}.",
- // VALID_PRIMITIVE_TYPES.join(", ")
- // )))
- // } else {
- // Ok(Self {
- // inner_type: data_type,
- // })
- // }
 }

 #[getter]
@@ -145,7 +121,7 @@

 #[pyo3(text_signature = "($self)")]
 fn to_json(&self) -> PyResult<String> {
- let inner_type = DataType::Primitive(self.inner_type);
+ let inner_type = DataType::Primitive(self.inner_type.clone());
 serde_json::to_string(&inner_type).map_err(|err| PyException::new_err(err.to_string()))
 }

@@ -160,7 +136,7 @@ impl PrimitiveType {

 #[pyo3(text_signature = "($self)")]
 fn to_pyarrow(&self) -> PyResult<PyArrowType<ArrowDataType>> {
- let inner_type = DataType::Primitive(self.inner_type);
+ let inner_type = DataType::Primitive(self.inner_type.clone());
 Ok(PyArrowType((&inner_type).try_into().map_err(
 |err: ArrowError| PyException::new_err(err.to_string()),
 )?))
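The added `.clone()` on `self.inner_type` in the two hunks above suggests the kernel's `PrimitiveType` is `Clone` but not `Copy` (it carries data in its decimal variant), so it can no longer be moved out of `&self`. A minimal illustration under that assumption, with toy types:

```rust
// Toy illustration: moving a non-Copy value out of `&self` doesn't compile,
// so call sites clone instead.

#[derive(Clone, Debug, PartialEq)]
enum PrimitiveType {
    String,
    Decimal(u8, u8), // data-carrying variant; the enum is Clone but not Copy
}

#[derive(Debug)]
enum DataType {
    Primitive(PrimitiveType),
}

struct Wrapper {
    inner_type: PrimitiveType,
}

impl Wrapper {
    fn to_data_type(&self) -> DataType {
        // `DataType::Primitive(self.inner_type)` would fail with E0507
        // ("cannot move out of `self.inner_type`"); cloning is the fix.
        DataType::Primitive(self.inner_type.clone())
    }
}

fn main() {
    let w = Wrapper {
        inner_type: PrimitiveType::Decimal(38, 18),
    };
    println!("{:?}", w.to_data_type());
}
```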
@@ -455,7 +431,24 @@ impl Field {
 };

 let mut inner = StructField::new(name, ty, nullable);
- inner = inner.with_metadata(metadata);
+ inner = inner.with_metadata(metadata.iter().map(|(k, v)| {
+ (
+ k,
+ if let serde_json::Value::Number(n) = v {
+ n.as_i64().map_or_else(
+ || MetadataValue::String(v.to_string()),
+ |i| {
+ i32::try_from(i)
+ .ok()
+ .map(MetadataValue::Number)
+ .unwrap_or_else(|| MetadataValue::String(v.to_string()))
+ },
+ )
+ } else {
+ MetadataValue::String(v.to_string())
+ },
+ )
+ }));

 Ok(Self { inner })
 }

@@ -597,7 +590,6 @@ impl StructType {
 let inner_data: Vec<String> = self
 .inner_type
 .fields()
- .iter()
 .map(|field| {
 let field = Field {
 inner: field.clone(),
@@ -628,7 +620,6 @@ impl StructType {
 fn fields(&self) -> Vec<Field> {
 self.inner_type
 .fields()
- .iter()
 .map(|field| Field {
 inner: field.clone(),
 })
@@ -672,7 +663,6 @@ pub fn schema_to_pyobject(schema: &DeltaStructType, py: Python) -> PyResult<PyObject> {
 let fields: Vec<Field> = schema
 .fields()
- .iter()
 .map(|field| Field {
 inner: field.clone(),
 })
@@ -718,7 +708,6 @@ impl PySchema {
 let inner_data: Vec<String> = super_
 .inner_type
 .fields()
- .iter()
 .map(|field| {
 let field = Field {
 inner: field.clone(),
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 1e813318f8..fb41d55a09 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -1458,7 +1458,7 @@ def test_invalid_decimals(tmp_path: pathlib.Path, engine):

 with pytest.raises(
 SchemaMismatchError,
- match=re.escape("Invalid data type for Delta Lake: decimal(39,1)"),
+ match=re.escape("Invalid data type for Delta Lake: Decimal256(39, 1)"),
 ):
 write_deltalake(table_or_uri=tmp_path, mode="append", data=data, engine=engine)
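For reference, the metadata conversion added in `python/src/schema.rs` above can be exercised on its own. This sketch inlines the same logic with a toy `MetadataValue` (the real one lives in `deltalake::kernel`) and the real `serde_json` crate: JSON numbers become `MetadataValue::Number` only when they fit in an `i32`; everything else falls back to the string form.

```rust
// Toy MetadataValue mirroring the conversion in the hunk above.
#[derive(Debug, PartialEq)]
enum MetadataValue {
    Number(i32),
    String(String),
}

fn to_metadata_value(v: &serde_json::Value) -> MetadataValue {
    if let serde_json::Value::Number(n) = v {
        n.as_i64().map_or_else(
            // Not representable as i64 (e.g. a float): keep the string form.
            || MetadataValue::String(v.to_string()),
            |i| {
                i32::try_from(i)
                    .ok()
                    .map(MetadataValue::Number)
                    // Fits in i64 but not i32: fall back to a string.
                    .unwrap_or_else(|| MetadataValue::String(v.to_string()))
            },
        )
    } else {
        MetadataValue::String(v.to_string())
    }
}

fn main() {
    assert_eq!(
        to_metadata_value(&serde_json::json!(5)),
        MetadataValue::Number(5)
    );
    // Too big for i32: falls back to the string representation.
    assert_eq!(
        to_metadata_value(&serde_json::json!(5_000_000_000i64)),
        MetadataValue::String("5000000000".into())
    );
}
```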