From f5c2645ac887ec4137a0d18b21aad2c44c0ce276 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:47:51 +0100 Subject: [PATCH] feat: generated columns Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- crates/core/src/delta_datafusion/mod.rs | 15 +- crates/core/src/kernel/error.rs | 9 + crates/core/src/kernel/models/actions.rs | 101 +++++++++-- crates/core/src/kernel/models/schema.rs | 72 ++++++++ crates/core/src/operations/add_column.rs | 26 +-- crates/core/src/operations/add_feature.rs | 4 +- crates/core/src/operations/cdc.rs | 2 +- crates/core/src/operations/create.rs | 45 ++--- .../src/operations/transaction/protocol.rs | 44 +---- crates/core/src/operations/write.rs | 167 +++++++++++------- crates/core/src/table/mod.rs | 36 ++++ 11 files changed, 358 insertions(+), 163 deletions(-) diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 6aad9ab3f7..08c38ddff3 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -81,7 +81,7 @@ use crate::kernel::{Add, DataCheck, EagerSnapshot, Invariant, Snapshot, StructTy use crate::logstore::LogStoreRef; use crate::table::builder::ensure_table_uri; use crate::table::state::DeltaTableState; -use crate::table::Constraint; +use crate::table::{Constraint, GeneratedColumn}; use crate::{open_table, open_table_with_storage_options, DeltaTable}; pub(crate) const PATH_COLUMN: &str = "__delta_rs_path"; @@ -1159,6 +1159,7 @@ pub(crate) async fn execute_plan_to_batch( pub struct DeltaDataChecker { constraints: Vec, invariants: Vec, + generated_columns: Vec, non_nullable_columns: Vec, ctx: SessionContext, } @@ -1169,6 +1170,7 @@ impl DeltaDataChecker { Self { invariants: vec![], constraints: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1179,6 +1181,7 @@ impl DeltaDataChecker { 
Self { invariants, constraints: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1189,6 +1192,7 @@ impl DeltaDataChecker { Self { constraints, invariants: vec![], + generated_columns: vec![], non_nullable_columns: vec![], ctx: DeltaSessionContext::default().into(), } @@ -1209,6 +1213,10 @@ impl DeltaDataChecker { /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let invariants = snapshot.schema().get_invariants().unwrap_or_default(); + let generated_columns = snapshot + .schema() + .get_generated_columns() + .unwrap_or_default(); let constraints = snapshot.table_config().get_constraints(); let non_nullable_columns = snapshot .schema() @@ -1224,6 +1232,7 @@ impl DeltaDataChecker { Self { invariants, constraints, + generated_columns, non_nullable_columns, ctx: DeltaSessionContext::default().into(), } @@ -1236,7 +1245,9 @@ impl DeltaDataChecker { pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> { self.check_nullability(record_batch)?; self.enforce_checks(record_batch, &self.invariants).await?; - self.enforce_checks(record_batch, &self.constraints).await + self.enforce_checks(record_batch, &self.constraints).await?; + self.enforce_checks(record_batch, &self.generated_columns) + .await } /// Return true if all the nullability checks are valid diff --git a/crates/core/src/kernel/error.rs b/crates/core/src/kernel/error.rs index cefe81bf9d..fe34b1d7e4 100644 --- a/crates/core/src/kernel/error.rs +++ b/crates/core/src/kernel/error.rs @@ -65,6 +65,15 @@ pub enum Error { line: String, }, + /// Error returned when a column's generation expression contains invalid JSON. + #[error("Invalid JSON in generation expression, line=`{line}`, err=`{json_err}`")] + InvalidGenerationExpressionJson { + /// JSON error details returned when parsing the generation expression JSON. + json_err: serde_json::error::Error, + /// Generation expression. 
+ line: String, + }, + #[error("Table metadata is invalid: {0}")] MetadataError(String), diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index 3812dc4838..d825d5bec4 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -2,12 +2,14 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{self, Display}; use std::str::FromStr; +use delta_kernel::schema::{DataType, StructField}; use maplit::hashset; use serde::{Deserialize, Serialize}; use tracing::warn; use url::Url; use super::schema::StructType; +use super::StructTypeExt; use crate::kernel::{error::Error, DeltaResult}; use crate::TableProperty; use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; @@ -115,6 +117,19 @@ impl Metadata { } } +/// checks if table contains timestamp_ntz in any field including nested fields. +pub fn contains_timestampntz<'a>(mut fields: impl Iterator) -> bool { + fn _check_type(dtype: &DataType) -> bool { + match dtype { + &DataType::TIMESTAMP_NTZ => true, + DataType::Array(inner) => _check_type(inner.element_type()), + DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())), + _ => false, + } + } + fields.any(|f| _check_type(f.data_type())) +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] #[serde(rename_all = "camelCase")] /// Defines a protocol action @@ -146,8 +161,8 @@ impl Protocol { } } - /// set the reader features in the protocol action, automatically bumps min_reader_version - pub fn with_reader_features( + /// Append the reader features in the protocol action, automatically bumps min_reader_version + pub fn append_reader_features( mut self, reader_features: impl IntoIterator>, ) -> Self { @@ -156,14 +171,20 @@ impl Protocol { .map(Into::into) .collect::>(); if !all_reader_features.is_empty() { - self.min_reader_version = 3 + self.min_reader_version = 3; + match self.reader_features { + Some(mut features) => { + 
features.extend(all_reader_features); + self.reader_features = Some(features); + } + None => self.reader_features = Some(all_reader_features), + }; } - self.reader_features = Some(all_reader_features); self } - /// set the writer features in the protocol action, automatically bumps min_writer_version - pub fn with_writer_features( + /// Append the writer features in the protocol action, automatically bumps min_writer_version + pub fn append_writer_features( mut self, writer_features: impl IntoIterator>, ) -> Self { @@ -172,9 +193,16 @@ impl Protocol { .map(|c| c.into()) .collect::>(); if !all_writer_feautures.is_empty() { - self.min_writer_version = 7 + self.min_writer_version = 7; + + match self.writer_features { + Some(mut features) => { + features.extend(all_writer_feautures); + self.writer_features = Some(features); + } + None => self.writer_features = Some(all_writer_feautures), + }; } - self.writer_features = Some(all_writer_feautures); self } @@ -255,6 +283,32 @@ impl Protocol { } self } + + /// Will apply the column metadata to the protocol by either bumping the version or setting + /// features + pub fn apply_column_metadata_to_protocol( + mut self, + schema: &StructType, + ) -> DeltaResult { + let generated_cols = schema.get_generated_columns()?; + let invariants = schema.get_invariants()?; + let contains_timestamp_ntz = self.contains_timestampntz(schema.fields()); + + if contains_timestamp_ntz { + self = self.enable_timestamp_ntz() + } + + if !generated_cols.is_empty() { + self = self.enable_generated_columns() + } + + if !invariants.is_empty() { + self = self.enable_invariants() + } + + Ok(self) + } + /// Will apply the properties to the protocol by either bumping the version or setting /// features pub fn apply_properties_to_protocol( @@ -391,10 +445,35 @@ impl Protocol { } Ok(self) } + + /// checks if table contains timestamp_ntz in any field including nested fields. 
+ fn contains_timestampntz<'a>(&self, fields: impl Iterator) -> bool { + contains_timestampntz(fields) + } + /// Enable timestamp_ntz in the protocol - pub fn enable_timestamp_ntz(mut self) -> Protocol { - self = self.with_reader_features(vec![ReaderFeatures::TimestampWithoutTimezone]); - self = self.with_writer_features(vec![WriterFeatures::TimestampWithoutTimezone]); + fn enable_timestamp_ntz(mut self) -> Self { + self = self.append_reader_features([ReaderFeatures::TimestampWithoutTimezone]); + self = self.append_writer_features([WriterFeatures::TimestampWithoutTimezone]); + self + } + + /// Enable generated columns + fn enable_generated_columns(mut self) -> Self { + if self.min_writer_version < 4 { + self.min_writer_version = 4; + } + if self.min_writer_version >= 7 { + self = self.append_writer_features([WriterFeatures::GeneratedColumns]); + } + self + } + + /// Enable invariants + fn enable_invariants(mut self) -> Self { + if self.min_writer_version >= 7 { + self = self.append_writer_features([WriterFeatures::Invariants]); + } self } } diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 3a88564f1d..976fe467ef 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -10,6 +10,7 @@ use serde_json::Value; use crate::kernel::error::Error; use crate::kernel::DataCheck; +use crate::table::GeneratedColumn; /// Type alias for a top level schema pub type Schema = StructType; @@ -49,9 +50,80 @@ impl DataCheck for Invariant { pub trait StructTypeExt { /// Get all invariants in the schemas fn get_invariants(&self) -> Result, Error>; + + /// Get all generated column expressions + fn get_generated_columns(&self) -> Result, Error>; } impl StructTypeExt for StructType { + /// Get all generated column expressions + fn get_generated_columns(&self) -> Result, Error> { + let mut remaining_fields: Vec<(String, StructField)> = self + .fields() + .map(|field| (field.name.clone(), 
field.clone())) + .collect(); + let mut generated_cols: Vec = Vec::new(); + + let add_segment = |prefix: &str, segment: &str| -> String { + if prefix.is_empty() { + segment.to_owned() + } else { + format!("{prefix}.{segment}") + } + }; + + while let Some((field_path, field)) = remaining_fields.pop() { + match field.data_type() { + DataType::Struct(inner) => { + remaining_fields.extend( + inner + .fields() + .map(|field| { + let new_prefix = add_segment(&field_path, &field.name); + (new_prefix, field.clone()) + }) + .collect::>(), + ); + } + DataType::Array(inner) => { + let element_field_name = add_segment(&field_path, "element"); + remaining_fields.push(( + element_field_name, + StructField::new("".to_string(), inner.element_type.clone(), false), + )); + } + DataType::Map(inner) => { + let key_field_name = add_segment(&field_path, "key"); + remaining_fields.push(( + key_field_name, + StructField::new("".to_string(), inner.key_type.clone(), false), + )); + let value_field_name = add_segment(&field_path, "value"); + remaining_fields.push(( + value_field_name, + StructField::new("".to_string(), inner.value_type.clone(), false), + )); + } + _ => {} + } + if let Some(MetadataValue::String(generated_col_string)) = field + .metadata + .get(ColumnMetadataKey::GenerationExpression.as_ref()) + { + let json: Value = serde_json::from_str(generated_col_string).map_err(|e| { + Error::InvalidGenerationExpressionJson { + json_err: e, + line: generated_col_string.to_string(), + } + })?; + if let Value::String(sql) = json { + generated_cols.push(GeneratedColumn::new(&field_path, &sql)); + } + } + } + Ok(generated_cols) + } + /// Get all invariants in the schemas fn get_invariants(&self) -> Result, Error> { let mut remaining_fields: Vec<(String, StructField)> = self diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs index 2b6d9de7df..a3477405af 100644 --- a/crates/core/src/operations/add_column.rs +++ 
b/crates/core/src/operations/add_column.rs @@ -88,24 +88,12 @@ impl std::future::IntoFuture for AddColumnBuilder { let table_schema = this.snapshot.schema(); let new_table_schema = merge_delta_struct(table_schema, fields_right)?; - // TODO(ion): Think of a way how we can simply this checking through the API or centralize some checks. - let contains_timestampntz = PROTOCOL.contains_timestampntz(fields.iter()); - let protocol = this.snapshot.protocol(); - - let maybe_new_protocol = if contains_timestampntz { - let updated_protocol = protocol.clone().enable_timestamp_ntz(); - if !(protocol.min_reader_version == 3 && protocol.min_writer_version == 7) { - // Convert existing properties to features since we advanced the protocol to v3,7 - Some( - updated_protocol - .move_table_properties_into_features(&metadata.configuration), - ) - } else { - Some(updated_protocol) - } - } else { - None - }; + let current_protocol = this.snapshot.protocol(); + + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&new_table_schema)? 
+ .move_table_properties_into_features(&metadata.configuration); let operation = DeltaOperation::AddColumn { fields: fields.into_iter().collect_vec(), @@ -115,7 +103,7 @@ impl std::future::IntoFuture for AddColumnBuilder { let mut actions = vec![metadata.into()]; - if let Some(new_protocol) = maybe_new_protocol { + if current_protocol != &new_protocol { actions.push(new_protocol.into()) } diff --git a/crates/core/src/operations/add_feature.rs b/crates/core/src/operations/add_feature.rs index 0e7f88ee7f..31dbb928bf 100644 --- a/crates/core/src/operations/add_feature.rs +++ b/crates/core/src/operations/add_feature.rs @@ -123,8 +123,8 @@ impl std::future::IntoFuture for AddTableFeatureBuilder { } } - protocol = protocol.with_reader_features(reader_features); - protocol = protocol.with_writer_features(writer_features); + protocol = protocol.append_reader_features(reader_features); + protocol = protocol.append_writer_features(writer_features); let operation = DeltaOperation::AddFeature { name: name.to_vec(), diff --git a/crates/core/src/operations/cdc.rs b/crates/core/src/operations/cdc.rs index c9d0ca0665..5e950402b8 100644 --- a/crates/core/src/operations/cdc.rs +++ b/crates/core/src/operations/cdc.rs @@ -175,7 +175,7 @@ mod tests { #[tokio::test] async fn test_should_write_cdc_v7_table_with_writer_feature() { let protocol = - Protocol::new(1, 7).with_writer_features(vec![WriterFeatures::ChangeDataFeed]); + Protocol::new(1, 7).append_writer_features(vec![WriterFeatures::ChangeDataFeed]); let actions = vec![Action::Protocol(protocol)]; let mut table: DeltaTable = DeltaOps::new_in_memory() .create() diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index 5f6ef47bc0..bcf79650cf 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -289,24 +289,12 @@ impl CreateBuilder { self.pre_execute(operation_id).await?; let configuration = self.configuration; - let contains_timestampntz = 
PROTOCOL.contains_timestampntz(self.columns.iter()); - // TODO configure more permissive versions based on configuration. Also how should this ideally be handled? - // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata? - - let current_protocol = if contains_timestampntz { - Protocol { - min_reader_version: 3, - min_writer_version: 7, - writer_features: Some(hashset! {WriterFeatures::TimestampWithoutTimezone}), - reader_features: Some(hashset! {ReaderFeatures::TimestampWithoutTimezone}), - } - } else { - Protocol { - min_reader_version: PROTOCOL.default_reader_version(), - min_writer_version: PROTOCOL.default_writer_version(), - reader_features: None, - writer_features: None, - } + + let current_protocol = Protocol { + min_reader_version: PROTOCOL.default_reader_version(), + min_writer_version: PROTOCOL.default_writer_version(), + reader_features: None, + writer_features: None, }; let protocol = self @@ -319,18 +307,21 @@ impl CreateBuilder { }) .unwrap_or_else(|| current_protocol); - let protocol = protocol.apply_properties_to_protocol( - &configuration - .iter() - .map(|(k, v)| (k.clone(), v.clone().unwrap())) - .collect::>(), - self.raise_if_key_not_exists, - )?; + let schema = StructType::new(self.columns); - let protocol = protocol.move_table_properties_into_features(&configuration); + let protocol = protocol + .apply_properties_to_protocol( + &configuration + .iter() + .map(|(k, v)| (k.clone(), v.clone().unwrap())) + .collect::>(), + self.raise_if_key_not_exists, + )? + .apply_column_metadata_to_protocol(&schema)? + .move_table_properties_into_features(&configuration); let mut metadata = Metadata::try_new( - StructType::new(self.columns), + schema, self.partition_columns.unwrap_or_default(), configuration, )? 
diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index ef88fbf8e6..bb49e0fae9 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -5,7 +5,7 @@ use once_cell::sync::Lazy; use tracing::log::*; use super::{TableReference, TransactionError}; -use crate::kernel::{Action, DataType, EagerSnapshot, Schema, StructField}; +use crate::kernel::{contains_timestampntz, Action, DataType, EagerSnapshot, Schema, StructField}; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; use delta_kernel::table_features::{ReaderFeatures, WriterFeatures}; @@ -79,29 +79,13 @@ impl ProtocolChecker { Ok(()) } - /// checks if table contains timestamp_ntz in any field including nested fields. - pub fn contains_timestampntz<'a>( - &self, - mut fields: impl Iterator, - ) -> bool { - fn _check_type(dtype: &DataType) -> bool { - match dtype { - &DataType::TIMESTAMP_NTZ => true, - DataType::Array(inner) => _check_type(inner.element_type()), - DataType::Struct(inner) => inner.fields().any(|f| _check_type(f.data_type())), - _ => false, - } - } - fields.any(|f| _check_type(f.data_type())) - } - /// Check can write_timestamp_ntz pub fn check_can_write_timestamp_ntz( &self, snapshot: &DeltaTableState, schema: &Schema, ) -> Result<(), TransactionError> { - let contains_timestampntz = self.contains_timestampntz(schema.fields()); + let contains_timestampntz = contains_timestampntz(schema.fields()); let required_features: Option<&HashSet> = match snapshot.protocol().min_writer_version { 0..=6 => None, @@ -159,22 +143,6 @@ impl ProtocolChecker { _ => snapshot.protocol().writer_features.as_ref(), }; - if (4..7).contains(&min_writer_version) { - debug!("min_writer_version is less 4-6, checking for unsupported table features"); - if let Ok(schema) = snapshot.metadata().schema() { - for field in schema.fields() { - if field.metadata.contains_key( - 
crate::kernel::ColumnMetadataKey::GenerationExpression.as_ref(), - ) { - error!("The table contains `delta.generationExpression` settings on columns which mean this table cannot be currently written to by delta-rs"); - return Err(TransactionError::UnsupportedWriterFeatures(vec![ - WriterFeatures::GeneratedColumns, - ])); - } - } - } - } - if let Some(features) = required_features { let mut diff = features.difference(&self.writer_features).peekable(); if diff.peek().is_some() { @@ -246,15 +214,13 @@ pub static INSTANCE: Lazy = Lazy::new(|| { #[cfg(feature = "cdf")] { writer_features.insert(WriterFeatures::ChangeDataFeed); - writer_features.insert(WriterFeatures::GeneratedColumns); } #[cfg(feature = "datafusion")] { writer_features.insert(WriterFeatures::Invariants); writer_features.insert(WriterFeatures::CheckConstraints); + writer_features.insert(WriterFeatures::GeneratedColumns); } - // writer_features.insert(WriterFeatures::ChangeDataFeed); - // writer_features.insert(WriterFeatures::GeneratedColumns); // writer_features.insert(WriterFeatures::ColumnMapping); // writer_features.insert(WriterFeatures::IdentityColumns); @@ -584,7 +550,7 @@ mod tests { let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone()); let actions = vec![ Action::Protocol( - Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::ChangeDataFeed]), + Protocol::new(2, 4).append_writer_features(vec![WriterFeatures::ChangeDataFeed]), ), metadata_action(None).into(), ]; @@ -601,7 +567,7 @@ mod tests { let checker_5 = ProtocolChecker::new(READER_V2.clone(), WRITER_V4.clone()); let actions = vec![ Action::Protocol( - Protocol::new(2, 4).with_writer_features(vec![WriterFeatures::GeneratedColumns]), + Protocol::new(2, 4).append_writer_features([WriterFeatures::GeneratedColumns]), ), metadata_action(None).into(), ]; diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 0be11f05cf..1a1cc9d11f 100644 --- a/crates/core/src/operations/write.rs 
+++ b/crates/core/src/operations/write.rs @@ -25,6 +25,7 @@ //! ```` use std::collections::HashMap; +use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -35,7 +36,7 @@ use arrow_cast::can_cast_types; use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion_common::DFSchema; -use datafusion_expr::{lit, Expr}; +use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::expressions::{self}; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::filter::FilterExec; @@ -63,14 +64,15 @@ use crate::delta_datafusion::{ use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{ - Action, ActionType, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType, + Action, ActionType, Add, AddCDCFile, DataCheck, Metadata, PartitionsExt, Remove, StructType, + StructTypeExt, }; use crate::logstore::LogStoreRef; use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; -use crate::table::Constraint as DeltaConstraint; +use crate::table::{Constraint as DeltaConstraint, GeneratedColumn}; use crate::writer::record_batch::divide_by_partition_values; use crate::DeltaTable; @@ -852,8 +854,14 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; + + let generated_col_expressions = this + .snapshot + .as_ref() + .map(|v| v.schema().get_generated_columns().unwrap_or_default()) + .unwrap_or_default(); let mut schema_drift = false; - let plan = if let Some(plan) = this.input { + let mut plan = if let Some(plan) = this.input { if this.schema_mode == Some(SchemaMode::Merge) { return 
Err(DeltaTableError::Generic( "Schema merge not supported yet for Datafusion".to_string(), @@ -864,11 +872,22 @@ impl std::future::IntoFuture for WriteBuilder { if batches.is_empty() { Err(WriteError::MissingData) } else { - let schema = batches[0].schema(); + let mut schema = batches[0].schema(); + // Schema merging code should be aware of columns that can be generated during write + // so they might be empty in the batch, but they will exist in the input_schema() + // in this case we have to insert the generated column and its type in the schema of the batch let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot.input_schema()?; + + // Initial schema merge round when there are generated column expressions + // This is to have the batch schema be the same as the input schema without adding new fields + // from the incoming batch + if !generated_col_expressions.is_empty() { + schema = merge_arrow_schema(table_schema.clone(), schema, true)?; + } + if let Err(schema_err) = try_cast_batch(schema.fields(), table_schema.fields()) { @@ -876,7 +895,11 @@ impl std::future::IntoFuture for WriteBuilder { if this.mode == SaveMode::Overwrite && this.schema_mode == Some(SchemaMode::Overwrite) { - new_schema = None // we overwrite anyway, so no need to cast + if generated_col_expressions.is_empty() { + new_schema = None // we overwrite anyway, so no need to cast + } else { + new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null + } } else if this.schema_mode == Some(SchemaMode::Merge) { new_schema = Some(merge_arrow_schema( table_schema.clone(), @@ -889,7 +912,11 @@ impl std::future::IntoFuture for WriteBuilder { } else if this.mode == SaveMode::Overwrite && this.schema_mode == Some(SchemaMode::Overwrite) { - new_schema = None // we overwrite anyway, so no need to cast + if generated_col_expressions.is_empty() { + new_schema = None // we overwrite anyway, so no need to cast + } else { 
+ new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null + } } else { // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table // and batch contains such type. Other types are preserved from the table side. @@ -912,7 +939,7 @@ impl std::future::IntoFuture for WriteBuilder { &batch, new_schema, this.safe_cast, - schema_drift, // Schema drifted so we have to add the missing columns/structfields. + schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols. )?, None => batch, }; @@ -949,7 +976,7 @@ impl std::future::IntoFuture for WriteBuilder { &batch, new_schema.clone(), this.safe_cast, - schema_drift, // Schema drifted so we have to add the missing columns/structfields. + schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols. 
)?); num_added_rows += batch.num_rows(); } @@ -972,40 +999,25 @@ impl std::future::IntoFuture for WriteBuilder { } else { Err(WriteError::MissingData) }?; + let schema = plan.schema(); if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { if let Some(snapshot) = &this.snapshot { let schema_struct: StructType = schema.clone().try_into()?; let current_protocol = snapshot.protocol(); let configuration = snapshot.metadata().configuration.clone(); - let maybe_new_protocol = if PROTOCOL - .contains_timestampntz(schema_struct.fields()) - && !current_protocol - .reader_features - .clone() - .unwrap_or_default() - .contains(&delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone) - // We can check only reader features, as reader and writer timestampNtz - // should be always enabled together - { - let new_protocol = current_protocol.clone().enable_timestamp_ntz(); - if !(current_protocol.min_reader_version == 3 - && current_protocol.min_writer_version == 7) - { - Some(new_protocol.move_table_properties_into_features(&configuration)) - } else { - Some(new_protocol) - } - } else { - None - }; + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&schema_struct)? 
+ .move_table_properties_into_features(&configuration); + let schema_action = Action::Metadata(Metadata::try_new( schema_struct, partition_columns.clone(), configuration, )?); actions.push(schema_action); - if let Some(new_protocol) = maybe_new_protocol { + if current_protocol != &new_protocol { actions.push(new_protocol.into()) } } @@ -1019,6 +1031,55 @@ impl std::future::IntoFuture for WriteBuilder { } }; + // Add when.then expr for generated columns + if !generated_col_expressions.is_empty() { + fn create_field( + idx: usize, + field: &arrow_schema::Field, + generated_cols_map: &HashMap, + state: &datafusion::execution::session_state::SessionState, + dfschema: &DFSchema, + ) -> DeltaResult<(Arc, String)> { + match generated_cols_map.get(field.name()) { + Some(generated_col) => { + let generation_expr = state.create_physical_expr( + when( + col(generated_col.get_name()).is_null(), + state.create_logical_expr( + generated_col.get_generation_expression(), + dfschema, + )?, + ) + .otherwise(col(generated_col.get_name()))?, + dfschema, + )?; + Ok((generation_expr, field.name().to_owned())) + } + None => Ok(( + Arc::new(expressions::Column::new(field.name(), idx)), + field.name().to_owned(), + )), + } + } + + let dfschema: DFSchema = schema.as_ref().clone().try_into()?; + let generated_cols_map = generated_col_expressions + .into_iter() + .map(|v| (v.name.clone(), v)) + .collect::>(); + let current_fields: DeltaResult, String)>> = plan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, field)| { + create_field(idx, field, &generated_cols_map, &state, &dfschema) + }) + .collect(); + + plan = Arc::new(ProjectionExec::try_new(current_fields?, plan.clone())?); + }; + let (predicate_str, predicate) = match this.predicate { Some(predicate) => { let pred = match predicate { @@ -1074,43 +1135,25 @@ impl std::future::IntoFuture for WriteBuilder { // Update metadata with new schema let table_schema = snapshot.input_schema()?; - let configuration = 
snapshot.metadata().configuration.clone(); - let current_protocol = snapshot.protocol(); - let maybe_new_protocol = if PROTOCOL.contains_timestampntz( - TryInto::::try_into(schema.clone())?.fields(), - ) && !current_protocol - .reader_features - .clone() - .unwrap_or_default() - .contains( - &delta_kernel::table_features::ReaderFeatures::TimestampWithoutTimezone, - ) - // We can check only reader features, as reader and writer timestampNtz - // should be always enabled together - { - let new_protocol = current_protocol.clone().enable_timestamp_ntz(); - if !(current_protocol.min_reader_version == 3 - && current_protocol.min_writer_version == 7) - { - Some(new_protocol.move_table_properties_into_features(&configuration)) - } else { - Some(new_protocol) - } - } else { - None - }; - - if let Some(protocol) = maybe_new_protocol { - actions.push(protocol.into()) - } - + let delta_schema: StructType = schema.as_ref().try_into()?; if schema != table_schema { let mut metadata = snapshot.metadata().clone(); - let delta_schema: StructType = schema.as_ref().try_into()?; + metadata.schema_string = serde_json::to_string(&delta_schema)?; actions.push(Action::Metadata(metadata)); } + let configuration = snapshot.metadata().configuration.clone(); + let current_protocol = snapshot.protocol(); + let new_protocol = current_protocol + .clone() + .apply_column_metadata_to_protocol(&delta_schema)? + .move_table_properties_into_features(&configuration); + + if current_protocol != &new_protocol { + actions.push(new_protocol.into()) + } + let deletion_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 1409c498c2..56bfd664cb 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -157,6 +157,42 @@ impl DataCheck for Constraint { } } +/// A generated column +#[derive(Eq, PartialEq, Debug, Default, Clone)] +pub struct GeneratedColumn { + /// The full path to the field. 
+ pub name: String, + /// The SQL string that generates the column value. + pub generation_expr: String, + /// The SQL string that must always evaluate to true. + pub validation_expr: String, +} + +impl GeneratedColumn { + /// Create a new generated column + pub fn new(field_name: &str, sql_generation: &str) -> Self { + Self { + name: field_name.to_string(), + generation_expr: sql_generation.to_string(), + validation_expr: format!("{} <=> {}", field_name, sql_generation), + } + } + + pub fn get_generation_expression(&self) -> &str { + &self.generation_expr + } +} + +impl DataCheck for GeneratedColumn { + fn get_name(&self) -> &str { + &self.name + } + + fn get_expression(&self) -> &str { + &self.validation_expr + } +} + /// Return partition fields along with their data type from the current schema. pub(crate) fn get_partition_col_data_types<'a>( schema: &'a StructType,