diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 220c5b21d9..8b80dc0a9f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,8 +42,6 @@ jobs: toolchain: stable override: true - - uses: Swatinem/rust-cache@v2 - - name: build and lint with clippy run: cargo clippy --features azure,datafusion,s3,gcs,glue --tests @@ -82,8 +80,6 @@ jobs: toolchain: "stable" override: true - - uses: Swatinem/rust-cache@v2 - - name: Run tests run: cargo test --verbose --features datafusion,azure @@ -118,22 +114,6 @@ jobs: toolchain: stable override: true - # - uses: actions/setup-java@v3 - # with: - # distribution: "zulu" - # java-version: "17" - - # - uses: beyondstorage/setup-hdfs@master - # with: - # hdfs-version: "3.3.2" - - # - name: Set Hadoop env - # run: | - # echo "CLASSPATH=$CLASSPATH:`hadoop classpath --glob`" >> $GITHUB_ENV - # echo "LD_LIBRARY_PATH=$JAVA_HOME/lib/server" >> $GITHUB_ENV - - - uses: Swatinem/rust-cache@v2 - - name: Start emulated services run: docker-compose up -d diff --git a/.gitignore b/.gitignore index ca0576b47c..84fc17c5f2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/ /.idea .vscode .env +.venv **/.DS_Store **/.python-version .coverage @@ -29,4 +30,4 @@ Cargo.lock justfile site -__pycache__ \ No newline at end of file +__pycache__ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ee258a3ce8..4472a3640a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Contributing to delta-rs -Development on this project is mostly driven by volunteer contributors. We welcome new contributors, including not only those who develop new features, but also those who are able to help with documentation and provide detailed bug reports. +Development on this project is mostly driven by volunteer contributors. We welcome new contributors, including not only those who develop new features, but also those who are able to help with documentation and provide detailed bug reports. Please take note of our [code of conduct](CODE_OF_CONDUCT.md). @@ -31,7 +31,7 @@ python -m pytest tests/test_writer.py -s -k "test_with_deltalake_schema" - Run some Rust code, e.g. 
run an example ``` cd crates/deltalake -cargo run --examples basic_operations +cargo run --example basic_operations --features="datafusion" ``` ## Run the docs locally diff --git a/Cargo.toml b/Cargo.toml index a6819cde0b..3a99e8ee63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ arrow-ord = { version = "50" } arrow-row = { version = "50" } arrow-schema = { version = "50" } arrow-select = { version = "50" } -object_store = { version = "0.9" } +object_store = { version = "=0.9.0" } parquet = { version = "50" } # datafusion @@ -59,7 +59,7 @@ serde_json = "1" # "stdlib" bytes = { version = "1" } -chrono = { version = "0.4.31", default-features = false, features = ["clock"] } +chrono = { version = "=0.4.34", default-features = false, features = ["clock"] } tracing = { version = "0.1", features = ["log"] } regex = { version = "1" } thiserror = { version = "1" } diff --git a/README.md b/README.md index 927b68ee63..100d34355d 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ write_deltalake("./data/delta", df) dt = DeltaTable("./data/delta") df2 = dt.to_pandas() -assert df == df2 +assert df.equals(df2) ``` The same table can also be loaded using the core Rust crate: @@ -91,7 +91,7 @@ async fn main() -> Result<(), DeltaTableError> { let table = open_table("./data/delta").await?; // show all active files in the table - let files = table.get_files(); + let files: Vec<_> = table.get_file_uris()?.collect(); println!("{:?}", files); Ok(()) @@ -130,45 +130,45 @@ of features outlined in the Delta [protocol][protocol] is also [tracked](#protoc ### Cloud Integrations -| Storage | Rust | Python | Comment | -| -------------------- | :-----: | :-----: | ----------------------------------- | -| Local | ![done] | ![done] | | -| S3 - AWS | ![done] | ![done] | requires lock for concurrent writes | -| S3 - MinIO | ![done] | ![done] | requires lock for concurrent writes | -| S3 - R2 | ![done] | ![done] | requires lock for concurrent writes | -| Azure Blob | ![done] | ![done] | | -| Azure ADLS Gen2 | ![done] | ![done] | | -| Microsoft OneLake | ![done] | ![done] | | -| Google Cloud Storage | ![done] | ![done] | | +| Storage | Rust | Python | Comment | +| -------------------- | :-----: | :-----: | ---------------------------------------------------------------- | +| Local | ![done] | ![done] | | +| S3 - AWS | ![done] | ![done] | requires lock for concurrent writes | +| S3 - MinIO | ![done] | ![done] | requires lock for concurrent writes | +| S3 - R2 | ![done] | ![done] | No lock required when using `AmazonS3ConfigKey::CopyIfNotExists` | +| Azure Blob | ![done] | ![done] | | +| Azure ADLS Gen2 | ![done] | ![done] | | +| Microsoft OneLake | ![done] | ![done] | | +| Google Cloud Storage | ![done] | ![done] | | ### Supported Operations -| Operation | Rust | Python | Description | -| --------------------- | :----------------------: | :----------------------: | ------------------------------------------- | -| Create | ![done] | ![done] | Create a new table | -| Read | ![done] | ![done] | Read data from a table | -| Vacuum | ![done] | ![done] | Remove unused files and log entries | -| Delete - partitions | | ![done] | Delete a table partition | -| Delete - predicates | ![done] | ![done] | Delete data based on a predicate | -| Optimize - compaction | ![done] | ![done] | Harmonize the size of data file | -| Optimize - Z-order | ![done] | ![done] | Place similar data into the same file | -| Merge | ![done] | ![done] | Merge a target Delta table with source data | -| FS check | ![done] | ![done] | Remove 
corrupted files from table | +| Operation | Rust | Python | Description | +| --------------------- | :-----: | :-----: | ------------------------------------------- | +| Create | ![done] | ![done] | Create a new table | +| Read | ![done] | ![done] | Read data from a table | +| Vacuum | ![done] | ![done] | Remove unused files and log entries | +| Delete - partitions | | ![done] | Delete a table partition | +| Delete - predicates | ![done] | ![done] | Delete data based on a predicate | +| Optimize - compaction | ![done] | ![done] | Harmonize the size of data file | +| Optimize - Z-order | ![done] | ![done] | Place similar data into the same file | +| Merge | ![done] | ![done] | Merge a target Delta table with source data | +| FS check | ![done] | ![done] | Remove corrupted files from table | ### Protocol Support Level -| Writer Version | Requirement | Status | -| -------------- | --------------------------------------------- | :------------------: | -| Version 2 | Append Only Tables | ![done] | -| Version 2 | Column Invariants | ![done] | -| Version 3 | Enforce `delta.checkpoint.writeStatsAsJson` | [![open]][writer-rs] | -| Version 3 | Enforce `delta.checkpoint.writeStatsAsStruct` | [![open]][writer-rs] | +| Writer Version | Requirement | Status | +| -------------- | --------------------------------------------- | :-------------------------------: | +| Version 2 | Append Only Tables | ![done] | +| Version 2 | Column Invariants | ![done] | +| Version 3 | Enforce `delta.checkpoint.writeStatsAsJson` | [![open]][writer-rs] | +| Version 3 | Enforce `delta.checkpoint.writeStatsAsStruct` | [![open]][writer-rs] | | Version 3 | CHECK constraints | [![semi-done]][check-constraints] | -| Version 4 | Change Data Feed | | -| Version 4 | Generated Columns | | -| Version 5 | Column Mapping | | -| Version 6 | Identity Columns | | -| Version 7 | Table Features | | +| Version 4 | Change Data Feed | | +| Version 4 | Generated Columns | | +| Version 5 | Column Mapping | | +| Version 6 | Identity Columns | | +| Version 7 | Table Features | | | Reader Version | Requirement | Status | | -------------- | ----------------------------------- | ------ | diff --git a/crates/aws/Cargo.toml b/crates/aws/Cargo.toml index 5d2bd5d69e..169686c240 100644 --- a/crates/aws/Cargo.toml +++ b/crates/aws/Cargo.toml @@ -13,10 +13,12 @@ rust-version.workspace = true [dependencies] deltalake-core = { version = "0.17.0", path = "../core" } -rusoto_core = { version = "0.47", default-features = false, optional = true } -rusoto_credential = { version = "0.47" } -rusoto_sts = { version = "0.47", default-features = false, optional = true } -rusoto_dynamodb = { version = "0.47", default-features = false, optional = true } +aws-smithy-runtime-api = { version="1.1.7" } +aws-smithy-runtime = { version="1.1.7", optional = true} +aws-credential-types = { version="1.1.7", features = ["hardcoded-credentials"]} +aws-config = { version = "1.1.6", default-features = false, features = ["behavior-version-latest","rt-tokio", "credentials-process", "sso"] } +aws-sdk-dynamodb = {version = "1.15.0", default-features = false, features = ["behavior-version-latest", "rt-tokio"] } +aws-sdk-sts = {version = "1.1.6", default-features = false, features = ["behavior-version-latest", "rt-tokio"] } lazy_static = "1" maplit = "1" @@ -32,6 +34,7 @@ regex = { workspace = true } uuid = { workspace = true, features = ["serde", "v4"] } url = { workspace = true } backoff = { version = "0.4", features = [ "tokio" ] } +hyper-tls = { version = "0.5", optional = true } 
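The dependency block above replaces the unmaintained rusoto crates with the official AWS SDK for Rust: `aws-config` resolves credentials and region into a shared `SdkConfig`, and the per-service crates (`aws-sdk-dynamodb`, `aws-sdk-sts`) build their clients from it. A minimal sketch of that flow, with an illustrative helper name that is not part of this diff:

```rust
use aws_sdk_dynamodb::Client;

// Sketch only: resolve the SDK configuration once, then hand it to each service client.
// `aws_config::load_from_env` is what the rustls path of `S3StorageOptions::from_map`
// calls further down in this diff.
async fn dynamodb_client() -> Client {
    let sdk_config = aws_config::load_from_env().await;
    Client::new(&sdk_config)
}
```

The resolved `SdkConfig` is cloned wherever it is needed, which is how the same configuration ends up shared between the S3 storage options and the DynamoDB lock client in the later hunks of this diff.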
[dev-dependencies] deltalake-core = { path = "../core", features = ["datafusion"] } @@ -46,12 +49,13 @@ serde_json = { workspace = true } default = ["rustls"] integration_test = [] native-tls = [ - "rusoto_core/native-tls", - "rusoto_sts/native-tls", - "rusoto_dynamodb/native-tls", + "aws-config/client-hyper", + "aws-smithy-runtime/connector-hyper-0-14-x", + "hyper-tls" ] rustls = [ - "rusoto_core/rustls", - "rusoto_sts/rustls", - "rusoto_dynamodb/rustls", + "aws-config/client-hyper", + "aws-config/rustls", + "aws-sdk-dynamodb/rustls", + "aws-sdk-sts/rustls", ] diff --git a/crates/aws/src/errors.rs b/crates/aws/src/errors.rs index bbce9dc426..a79542bfff 100644 --- a/crates/aws/src/errors.rs +++ b/crates/aws/src/errors.rs @@ -2,27 +2,45 @@ use std::num::ParseIntError; -use rusoto_core::RusotoError; -use rusoto_dynamodb::{CreateTableError, GetItemError, PutItemError, QueryError, UpdateItemError}; - -#[derive(thiserror::Error, Debug, PartialEq)] -pub enum DynamoDbConfigError { - /// Error raised creating http client - #[error("Failed to create request dispatcher: {source}")] - HttpClient { - /// The underlying Rusoto TlsError - #[from] - source: rusoto_core::request::TlsError, +use aws_credential_types::provider::error::CredentialsError; +use aws_sdk_dynamodb::{ + error::SdkError, + operation::{ + create_table::CreateTableError, get_item::GetItemError, put_item::PutItemError, + query::QueryError, update_item::UpdateItemError, }, +}; +use aws_smithy_runtime_api::client::result::ServiceError; + +macro_rules! impl_from_service_error { + ($error_type:ty) => { + impl From> for LockClientError + where + R: Send + Sync + std::fmt::Debug + 'static, + { + fn from(err: SdkError<$error_type, R>) -> Self { + match err { + SdkError::ServiceError(e) => e.into(), + _ => LockClientError::GenericDynamoDb { + source: Box::new(err), + }, + } + } + } - /// Error raised getting credentials - #[error("Failed to retrieve AWS credentials: {source}")] - Credentials { - /// The underlying Rusoto CredentialsError - #[from] - source: rusoto_credential::CredentialsError, - }, + impl From> for LockClientError + where + R: Send + Sync + std::fmt::Debug + 'static, + { + fn from(value: ServiceError<$error_type, R>) -> Self { + value.into_err().into() + } + } + }; +} +#[derive(thiserror::Error, Debug)] +pub enum DynamoDbConfigError { /// Billing mode string invalid #[error("Invalid billing mode : {0}, supported values : ['provided', 'pay_per_request']")] InvalidBillingMode(String), @@ -33,6 +51,9 @@ pub enum DynamoDbConfigError { // config_value: String, source: ParseIntError, }, + /// Cannot initialize DynamoDbConfiguration due to some sort of threading issue + #[error("Cannot initialize dynamodb lock configuration")] + InitializationError, } /// Errors produced by `DynamoDbLockClient` @@ -44,7 +65,7 @@ pub enum LockClientError { #[error("Lock table '{name}': creation failed: {source}")] LockTableCreateFailure { name: String, - source: RusotoError, + source: Box, }, #[error("Log entry for table '{table_path}' and version '{version}' already exists")] @@ -60,12 +81,8 @@ pub enum LockClientError { GenericDynamoDb { source: Box, }, - #[error("configuration error: {source}")] - Credentials { - source: rusoto_credential::CredentialsError, - }, - + Credentials { source: CredentialsError }, #[error( "Atomic rename requires a LockClient for S3 backends. 
\ Either configure the LockClient, or set AWS_S3_ALLOW_UNSAFE_RENAME=true \ @@ -77,12 +94,14 @@ pub enum LockClientError { impl From for LockClientError { fn from(err: GetItemError) -> Self { match err { - GetItemError::InternalServerError(_) => err.into(), - GetItemError::ProvisionedThroughputExceeded(_) => { + GetItemError::ProvisionedThroughputExceededException(_) => { LockClientError::ProvisionedThroughputExceeded } GetItemError::RequestLimitExceeded(_) => LockClientError::ProvisionedThroughputExceeded, - GetItemError::ResourceNotFound(_) => LockClientError::LockTableNotFound, + GetItemError::ResourceNotFoundException(_) => LockClientError::LockTableNotFound, + _ => LockClientError::GenericDynamoDb { + source: Box::new(err), + }, } } } @@ -90,12 +109,14 @@ impl From for LockClientError { impl From for LockClientError { fn from(err: QueryError) -> Self { match err { - QueryError::InternalServerError(_) => err.into(), - QueryError::ProvisionedThroughputExceeded(_) => { + QueryError::ProvisionedThroughputExceededException(_) => { LockClientError::ProvisionedThroughputExceeded } QueryError::RequestLimitExceeded(_) => LockClientError::ProvisionedThroughputExceeded, - QueryError::ResourceNotFound(_) => LockClientError::LockTableNotFound, + QueryError::ResourceNotFoundException(_) => LockClientError::LockTableNotFound, + _ => LockClientError::GenericDynamoDb { + source: Box::new(err), + }, } } } @@ -103,17 +124,19 @@ impl From for LockClientError { impl From for LockClientError { fn from(err: PutItemError) -> Self { match err { - PutItemError::ConditionalCheckFailed(_) => { + PutItemError::ConditionalCheckFailedException(_) => { unreachable!("error must be handled explicitely") } - PutItemError::InternalServerError(_) => err.into(), - PutItemError::ProvisionedThroughputExceeded(_) => { + PutItemError::ProvisionedThroughputExceededException(_) => { LockClientError::ProvisionedThroughputExceeded } PutItemError::RequestLimitExceeded(_) => LockClientError::ProvisionedThroughputExceeded, - PutItemError::ResourceNotFound(_) => LockClientError::LockTableNotFound, - PutItemError::ItemCollectionSizeLimitExceeded(_) => err.into(), - PutItemError::TransactionConflict(_) => err.into(), + PutItemError::ResourceNotFoundException(_) => LockClientError::LockTableNotFound, + PutItemError::ItemCollectionSizeLimitExceededException(_) => err.into(), + PutItemError::TransactionConflictException(_) => err.into(), + _ => LockClientError::GenericDynamoDb { + source: Box::new(err), + }, } } } @@ -121,34 +144,27 @@ impl From for LockClientError { impl From for LockClientError { fn from(err: UpdateItemError) -> Self { match err { - UpdateItemError::ConditionalCheckFailed(_) => { + UpdateItemError::ConditionalCheckFailedException(_) => { unreachable!("condition check failure in update is not an error") } UpdateItemError::InternalServerError(_) => err.into(), - UpdateItemError::ProvisionedThroughputExceeded(_) => { + UpdateItemError::ProvisionedThroughputExceededException(_) => { LockClientError::ProvisionedThroughputExceeded } UpdateItemError::RequestLimitExceeded(_) => { LockClientError::ProvisionedThroughputExceeded } - UpdateItemError::ResourceNotFound(_) => LockClientError::LockTableNotFound, - UpdateItemError::ItemCollectionSizeLimitExceeded(_) => err.into(), - UpdateItemError::TransactionConflict(_) => err.into(), - } - } -} - -impl From> for LockClientError -where - E: Into + std::error::Error + Send + Sync + 'static, -{ - fn from(err: RusotoError) -> Self { - match err { - RusotoError::Service(e) => e.into(), - 
RusotoError::Credentials(e) => LockClientError::Credentials { source: e }, + UpdateItemError::ResourceNotFoundException(_) => LockClientError::LockTableNotFound, + UpdateItemError::ItemCollectionSizeLimitExceededException(_) => err.into(), + UpdateItemError::TransactionConflictException(_) => err.into(), _ => LockClientError::GenericDynamoDb { source: Box::new(err), }, } } } + +impl_from_service_error!(GetItemError); +impl_from_service_error!(PutItemError); +impl_from_service_error!(QueryError); +impl_from_service_error!(UpdateItemError); diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index 2630f80512..026f0e0df9 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -2,8 +2,21 @@ pub mod errors; pub mod logstore; +#[cfg(feature = "native-tls")] +mod native; pub mod storage; - +use aws_config::SdkConfig; +use aws_sdk_dynamodb::{ + operation::{ + create_table::CreateTableError, get_item::GetItemError, put_item::PutItemError, + query::QueryError, update_item::UpdateItemError, + }, + types::{ + AttributeDefinition, AttributeValue, BillingMode, KeySchemaElement, KeyType, + ScalarAttributeType, + }, + Client, +}; use lazy_static::lazy_static; use object_store::aws::AmazonS3ConfigKey; use regex::Regex; @@ -18,14 +31,6 @@ use tracing::debug; use deltalake_core::logstore::{logstores, LogStore, LogStoreFactory}; use deltalake_core::storage::{factories, url_prefix_handler, ObjectStoreRef, StorageOptions}; use deltalake_core::{DeltaResult, Path}; -use rusoto_core::{HttpClient, Region, RusotoError}; -use rusoto_credential::AutoRefreshingProvider; -use rusoto_dynamodb::{ - AttributeDefinition, AttributeValue, CreateTableError, CreateTableInput, DynamoDb, - DynamoDbClient, GetItemError, GetItemInput, KeySchemaElement, PutItemError, PutItemInput, - QueryError, QueryInput, UpdateItemError, UpdateItemInput, -}; -use rusoto_sts::WebIdentityProvider; use url::Url; use errors::{DynamoDbConfigError, LockClientError}; @@ -53,7 +58,7 @@ impl LogStoreFactory for S3LogStoreFactory { )); } - let s3_options = S3StorageOptions::from_map(&options.0); + let s3_options = S3StorageOptions::from_map(&options.0)?; if s3_options.locking_provider.as_deref() != Some("dynamodb") { debug!("S3LogStoreFactory has been asked to create a LogStore without the dynamodb locking provider"); @@ -117,7 +122,7 @@ impl CommitEntry { /// Lock client backed by DynamoDb. pub struct DynamoDbLockClient { /// DynamoDb client - dynamodb_client: DynamoDbClient, + dynamodb_client: Client, /// configuration of the config: DynamoDbConfig, } @@ -131,24 +136,26 @@ impl std::fmt::Debug for DynamoDbLockClient { impl DynamoDbLockClient { /// Creates a new DynamoDbLockClient from the supplied storage options. 
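The `DynamoDbLockClient` methods that follow all branch on `SdkError::as_service_error()` instead of matching `RusotoError` variants. A condensed sketch of that pattern, using a hypothetical helper name:

```rust
use aws_sdk_dynamodb::error::SdkError;
use aws_sdk_dynamodb::operation::get_item::GetItemError;

// Sketch only: inspect the typed service error carried inside an SdkError, the same
// check the retry loops below perform for throttling and missing-table conditions.
fn is_missing_lock_table(err: &SdkError<GetItemError>) -> bool {
    matches!(
        err.as_service_error(),
        Some(GetItemError::ResourceNotFoundException(_))
    )
}
```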
pub fn try_new( + sdk_config: &SdkConfig, lock_table_name: Option, billing_mode: Option, max_elapsed_request_time: Option, - region: Region, - use_web_identity: bool, ) -> Result { - let dynamodb_client = create_dynamodb_client(region.clone(), use_web_identity)?; + let dynamodb_client = aws_sdk_dynamodb::Client::new(sdk_config); let lock_table_name = lock_table_name .or_else(|| std::env::var(constants::LOCK_TABLE_KEY_NAME).ok()) .unwrap_or(constants::DEFAULT_LOCK_TABLE_NAME.to_owned()); - let billing_mode = billing_mode + let billing_mode = if let Some(bm) = billing_mode .or_else(|| std::env::var(constants::BILLING_MODE_KEY_NAME).ok()) - .map_or_else( - || Ok(BillingMode::PayPerRequest), - |bm| BillingMode::from_str(&bm), - )?; + .as_ref() + { + BillingMode::try_parse(bm.to_ascii_uppercase().as_str()) + .map_err(|_| DynamoDbConfigError::InvalidBillingMode(String::default()))? + } else { + BillingMode::PayPerRequest + }; let max_elapsed_request_time = max_elapsed_request_time .or_else(|| std::env::var(constants::MAX_ELAPSED_REQUEST_TIME_KEY_NAME).ok()) @@ -162,8 +169,7 @@ impl DynamoDbLockClient { billing_mode, lock_table_name, max_elapsed_request_time, - use_web_identity, - region, + sdk_config: sdk_config.clone(), }; Ok(Self { dynamodb_client, @@ -179,40 +185,50 @@ impl DynamoDbLockClient { /// `active`, so transient failures might occurr when immediately using the lock client. pub async fn try_create_lock_table(&self) -> Result { let attribute_definitions = vec![ - AttributeDefinition { - attribute_name: constants::ATTR_TABLE_PATH.to_owned(), - attribute_type: constants::STRING_TYPE.to_owned(), - }, - AttributeDefinition { - attribute_name: constants::ATTR_FILE_NAME.to_owned(), - attribute_type: constants::STRING_TYPE.to_owned(), - }, + AttributeDefinition::builder() + .attribute_name(constants::ATTR_TABLE_PATH) + .attribute_type(ScalarAttributeType::S) + .build() + .unwrap(), + AttributeDefinition::builder() + .attribute_name(constants::ATTR_FILE_NAME) + .attribute_type(ScalarAttributeType::S) + .build() + .unwrap(), ]; - let input = CreateTableInput { - attribute_definitions, - key_schema: vec![ - KeySchemaElement { - attribute_name: constants::ATTR_TABLE_PATH.to_owned(), - key_type: constants::KEY_TYPE_HASH.to_owned(), - }, - KeySchemaElement { - attribute_name: constants::ATTR_FILE_NAME.to_owned(), - key_type: constants::KEY_TYPE_RANGE.to_owned(), - }, - ], - billing_mode: Some(self.config.billing_mode.to_str()), - table_name: self.config.lock_table_name.clone(), - ..Default::default() - }; - match self.dynamodb_client.create_table(input).await { + let request = self + .dynamodb_client + .create_table() + .set_attribute_definitions(Some(attribute_definitions)) + .set_key_schema(Some(vec![ + KeySchemaElement::builder() + .attribute_name(constants::ATTR_TABLE_PATH.to_owned()) + .key_type(KeyType::Hash) + .build() + .unwrap(), + KeySchemaElement::builder() + .attribute_name(constants::ATTR_FILE_NAME.to_owned()) + .key_type(KeyType::Range) + .build() + .unwrap(), + ])) + .billing_mode(self.config.billing_mode.clone()) + .table_name(&self.config.lock_table_name) + .send(); + match request.await { Ok(_) => Ok(CreateLockTableResult::TableCreated), - Err(RusotoError::Service(CreateTableError::ResourceInUse(_))) => { - Ok(CreateLockTableResult::TableAlreadyExists) - } - Err(reason) => Err(LockClientError::LockTableCreateFailure { - name: self.config.lock_table_name.clone(), - source: reason, - }), + Err(sdk_err) => match sdk_err.as_service_error() { + 
Some(CreateTableError::ResourceInUseException(_)) => { + Ok(CreateLockTableResult::TableAlreadyExists) + } + Some(_) => Err(LockClientError::LockTableCreateFailure { + name: self.config.lock_table_name.clone(), + source: Box::new(sdk_err.into_service_error()), + }), + _ => Err(LockClientError::GenericDynamoDb { + source: Box::new(sdk_err), + }), + }, } } @@ -238,22 +254,26 @@ impl DynamoDbLockClient { table_path: &str, version: i64, ) -> Result, LockClientError> { - let input = GetItemInput { - consistent_read: Some(true), - table_name: self.config.lock_table_name.clone(), - key: self.get_primary_key(version, table_path), - ..Default::default() - }; let item = self .retry(|| async { - match self.dynamodb_client.get_item(input.clone()).await { + match self + .dynamodb_client + .get_item() + .consistent_read(true) + .table_name(&self.config.lock_table_name) + .set_key(Some(self.get_primary_key(version, table_path))) + .send() + .await + { Ok(x) => Ok(x), - Err(RusotoError::Service(GetItemError::ProvisionedThroughputExceeded(_))) => { - Err(backoff::Error::transient( - LockClientError::ProvisionedThroughputExceeded, - )) - } - Err(err) => Err(backoff::Error::permanent(err.into())), + Err(sdk_err) => match sdk_err.as_service_error() { + Some(GetItemError::ProvisionedThroughputExceededException(_)) => { + Err(backoff::Error::transient( + LockClientError::ProvisionedThroughputExceeded, + )) + } + _ => Err(backoff::Error::permanent(sdk_err.into())), + }, } }) .await?; @@ -266,29 +286,33 @@ impl DynamoDbLockClient { table_path: &str, entry: &CommitEntry, ) -> Result<(), LockClientError> { - let item = create_value_map(entry, table_path); - let input = PutItemInput { - condition_expression: Some(constants::CONDITION_EXPR_CREATE.to_owned()), - table_name: self.get_lock_table_name(), - item, - ..Default::default() - }; self.retry(|| async { - match self.dynamodb_client.put_item(input.clone()).await { + let item = create_value_map(entry, table_path); + match self + .dynamodb_client + .put_item() + .condition_expression(constants::CONDITION_EXPR_CREATE.as_str()) + .table_name(self.get_lock_table_name()) + .set_item(Some(item)) + .send() + .await + { Ok(_) => Ok(()), - Err(RusotoError::Service(PutItemError::ProvisionedThroughputExceeded(_))) => Err( - backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), - ), - Err(RusotoError::Service(PutItemError::ConditionalCheckFailed(_))) => Err( - backoff::Error::permanent(LockClientError::VersionAlreadyExists { - table_path: table_path.to_owned(), - version: entry.version, - }), - ), - Err(RusotoError::Service(PutItemError::ResourceNotFound(_))) => Err( - backoff::Error::permanent(LockClientError::LockTableNotFound), - ), - Err(err) => Err(backoff::Error::permanent(err.into())), + Err(err) => match err.as_service_error() { + Some(PutItemError::ProvisionedThroughputExceededException(_)) => Err( + backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), + ), + Some(PutItemError::ConditionalCheckFailedException(_)) => Err( + backoff::Error::permanent(LockClientError::VersionAlreadyExists { + table_path: table_path.to_owned(), + version: entry.version, + }), + ), + Some(PutItemError::ResourceNotFoundException(_)) => Err( + backoff::Error::permanent(LockClientError::LockTableNotFound), + ), + _ => Err(backoff::Error::permanent(err.into())), + }, } }) .await @@ -312,25 +336,31 @@ impl DynamoDbLockClient { table_path: &str, limit: i64, ) -> Result, LockClientError> { - let input = QueryInput { - table_name: 
self.get_lock_table_name(), - consistent_read: Some(true), - limit: Some(limit), - scan_index_forward: Some(false), - key_condition_expression: Some(format!("{} = :tn", constants::ATTR_TABLE_PATH)), - expression_attribute_values: Some( - maplit::hashmap!(":tn".into() => string_attr(table_path)), - ), - ..Default::default() - }; let query_result = self .retry(|| async { - match self.dynamodb_client.query(input.clone()).await { + match self + .dynamodb_client + .query() + .table_name(self.get_lock_table_name()) + .consistent_read(true) + .limit(limit.try_into().unwrap_or(i32::MAX)) + .scan_index_forward(false) + .key_condition_expression(format!("{} = :tn", constants::ATTR_TABLE_PATH)) + .set_expression_attribute_values(Some( + maplit::hashmap!(":tn".into() => string_attr(table_path)), + )) + .send() + .await + { Ok(result) => Ok(result), - Err(RusotoError::Service(QueryError::ProvisionedThroughputExceeded(_))) => Err( - backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), - ), - Err(err) => Err(backoff::Error::permanent(err.into())), + Err(sdk_err) => match sdk_err.as_service_error() { + Some(QueryError::ProvisionedThroughputExceededException(_)) => { + Err(backoff::Error::transient( + LockClientError::ProvisionedThroughputExceeded, + )) + } + _ => Err(backoff::Error::permanent(sdk_err.into())), + }, } }) .await?; @@ -354,31 +384,32 @@ impl DynamoDbLockClient { .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); - let input = UpdateItemInput { - table_name: self.get_lock_table_name(), - key: self.get_primary_key(version, table_path), - update_expression: Some("SET complete = :c, expireTime = :e".to_owned()), - expression_attribute_values: Some(maplit::hashmap! { - ":c".to_owned() => string_attr("true"), - ":e".to_owned() => num_attr(seconds_since_epoch), - ":f".into() => string_attr("false"), - }), - condition_expression: Some(constants::CONDITION_UPDATE_INCOMPLETE.to_owned()), - ..Default::default() - }; - self.retry(|| async { - match self.dynamodb_client.update_item(input.clone()).await { + match self + .dynamodb_client + .update_item() + .table_name(self.get_lock_table_name()) + .set_key(Some(self.get_primary_key(version, table_path))) + .update_expression("SET complete = :c, expireTime = :e".to_owned()) + .set_expression_attribute_values(Some(maplit::hashmap! 
{ + ":c".to_owned() => string_attr("true"), + ":e".to_owned() => num_attr(seconds_since_epoch), + ":f".into() => string_attr("false"), + })) + .condition_expression(constants::CONDITION_UPDATE_INCOMPLETE) + .send() + .await + { Ok(_) => Ok(UpdateLogEntryResult::UpdatePerformed), - Err(RusotoError::Service(UpdateItemError::ConditionalCheckFailed(_))) => { - Ok(UpdateLogEntryResult::AlreadyCompleted) - } - Err(RusotoError::Service(UpdateItemError::ProvisionedThroughputExceeded(_))) => { - Err(backoff::Error::transient( - LockClientError::ProvisionedThroughputExceeded, - )) - } - Err(err) => Err(backoff::Error::permanent(err.into())), + Err(err) => match err.as_service_error() { + Some(UpdateItemError::ProvisionedThroughputExceededException(_)) => Err( + backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), + ), + Some(UpdateItemError::ConditionalCheckFailedException(_)) => { + Ok(UpdateLogEntryResult::AlreadyCompleted) + } + _ => Err(backoff::Error::permanent(err.into())), + }, } }) .await @@ -467,40 +498,23 @@ fn create_value_map( value_map } -#[derive(Debug, PartialEq)] -pub enum BillingMode { - PayPerRequest, - Provisioned, -} - -impl BillingMode { - fn to_str(&self) -> String { - match self { - Self::PayPerRequest => "PAY_PER_REQUEST".to_owned(), - Self::Provisioned => "PROVISIONED".to_owned(), - } - } -} - -impl FromStr for BillingMode { - type Err = DynamoDbConfigError; - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "provisioned" => Ok(BillingMode::Provisioned), - "pay_per_request" => Ok(BillingMode::PayPerRequest), - _ => Err(DynamoDbConfigError::InvalidBillingMode(s.to_owned())), - } - } -} - -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub struct DynamoDbConfig { pub billing_mode: BillingMode, pub lock_table_name: String, pub max_elapsed_request_time: Duration, - pub use_web_identity: bool, - pub region: Region, + pub sdk_config: SdkConfig, +} + +impl Eq for DynamoDbConfig {} +impl PartialEq for DynamoDbConfig { + fn eq(&self, other: &Self) -> bool { + self.billing_mode == other.billing_mode + && self.lock_table_name == other.lock_table_name + && self.max_elapsed_request_time == other.max_elapsed_request_time + && self.sdk_config.endpoint_url() == other.sdk_config.endpoint_url() + && self.sdk_config.region() == other.sdk_config.region() + } } /// Represents the possible, positive outcomes of calling `DynamoDbClient::try_create_lock_table()` @@ -545,23 +559,6 @@ pub mod constants { pub const DEFAULT_COMMIT_ENTRY_EXPIRATION_DELAY: Duration = Duration::from_secs(86_400); } -fn create_dynamodb_client( - region: Region, - use_web_identity: bool, -) -> Result { - Ok(match use_web_identity { - true => { - let dispatcher = HttpClient::new()?; - rusoto_dynamodb::DynamoDbClient::new_with( - dispatcher, - get_web_identity_provider()?, - region, - ) - } - false => rusoto_dynamodb::DynamoDbClient::new(region), - }) -} - /// Extract a field from an item's attribute value map, producing a descriptive error /// of the various failure cases. fn extract_required_string_field<'a>( @@ -573,12 +570,11 @@ fn extract_required_string_field<'a>( .ok_or_else(|| LockClientError::InconsistentData { description: format!("mandatory string field '{field_name}' missing"), })? 
- .s - .as_ref() - .ok_or_else(|| LockClientError::InconsistentData { + .as_s() + .map_err(|v| LockClientError::InconsistentData { description: format!( "mandatory string field '{field_name}' exists, but is not a string: {:#?}", - fields.get(field_name) + v, ), }) .map(|s| s.as_str()) @@ -593,35 +589,21 @@ fn extract_optional_number_field<'a>( fields .get(field_name) .map(|attr| { - attr.n - .as_ref() - .ok_or_else(|| LockClientError::InconsistentData { - description: format!( - "field with name '{field_name}' exists, but is not of type number" - ), - }) + attr.as_n().map_err(|_| LockClientError::InconsistentData { + description: format!( + "field with name '{field_name}' exists, but is not of type number" + ), + }) }) .transpose() } fn string_attr(s: T) -> AttributeValue { - AttributeValue { - s: Some(s.to_string()), - ..Default::default() - } + AttributeValue::S(s.to_string()) } fn num_attr(n: T) -> AttributeValue { - AttributeValue { - n: Some(n.to_string()), - ..Default::default() - } -} - -fn get_web_identity_provider( -) -> Result, DynamoDbConfigError> { - let provider = WebIdentityProvider::from_k8s_env(); - Ok(AutoRefreshingProvider::new(provider)?) + AttributeValue::N(n.to_string()) } lazy_static! { diff --git a/crates/aws/src/logstore.rs b/crates/aws/src/logstore.rs index 123aadd2d1..a9a98dbe18 100644 --- a/crates/aws/src/logstore.rs +++ b/crates/aws/src/logstore.rs @@ -45,6 +45,7 @@ impl S3DynamoDbLogStore { object_store: ObjectStoreRef, ) -> DeltaResult { let lock_client = DynamoDbLockClient::try_new( + &s3_options.sdk_config, s3_options .extra_opts .get(constants::LOCK_TABLE_KEY_NAME) @@ -57,13 +58,11 @@ impl S3DynamoDbLogStore { .extra_opts .get(constants::MAX_ELAPSED_REQUEST_TIME_KEY_NAME) .cloned(), - s3_options.region.clone(), - s3_options.use_web_identity, ) .map_err(|err| DeltaTableError::ObjectStore { source: ObjectStoreError::Generic { store: STORE_NAME, - source: err.into(), + source: Box::new(err), }, })?; let table_path = to_uri(&location, &Path::from("")); diff --git a/crates/aws/src/native.rs b/crates/aws/src/native.rs new file mode 100644 index 0000000000..c647194eb7 --- /dev/null +++ b/crates/aws/src/native.rs @@ -0,0 +1,12 @@ +use aws_sdk_sts::config::SharedHttpClient; +use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; + +pub fn use_native_tls_client(allow_http: bool) -> SharedHttpClient { + let mut tls_connector = hyper_tls::HttpsConnector::new(); + if allow_http { + tls_connector.https_only(false); + } + + let client = HyperClientBuilder::new().build(tls_connector); + client +} diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index 87d488b54f..ffd7bb6996 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -1,5 +1,6 @@ //! AWS S3 storage backend. 
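The new `native.rs` module above returns a `SharedHttpClient` when the `native-tls` feature is enabled. A hedged sketch of how such a client is handed to the SDK's `ConfigLoader`, mirroring what the `native-tls` branch of `S3StorageOptions::from_map` does below (the function name here is illustrative):

```rust
use aws_config::SdkConfig;
use aws_sdk_sts::config::SharedHttpClient;

// Sketch only: load the SDK configuration through a custom HTTP client instead of
// `aws_config::load_from_env`; assumes the behavior-version-latest feature enabled
// in the Cargo.toml changes above.
async fn load_config_with(http_client: SharedHttpClient) -> SdkConfig {
    aws_config::ConfigLoader::default()
        .http_client(http_client)
        .load()
        .await
}
```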
+use aws_config::{Region, SdkConfig}; use bytes::Bytes; use deltalake_core::storage::object_store::{ aws::AmazonS3ConfigKey, parse_url_opts, GetOptions, GetResult, ListResult, MultipartId, @@ -8,7 +9,7 @@ use deltalake_core::storage::object_store::{ use deltalake_core::storage::{str_is_truthy, ObjectStoreFactory, ObjectStoreRef, StorageOptions}; use deltalake_core::{DeltaResult, ObjectStoreError, Path}; use futures::stream::BoxStream; -use rusoto_core::Region; +use futures::Future; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Range; @@ -18,6 +19,10 @@ use std::time::Duration; use tokio::io::AsyncWrite; use url::Url; +use crate::errors::DynamoDbConfigError; +#[cfg(feature = "native-tls")] +use crate::native; + const STORE_NAME: &str = "DeltaS3ObjectStore"; #[derive(Clone, Default, Debug)] @@ -25,7 +30,21 @@ pub struct S3ObjectStoreFactory {} impl S3ObjectStoreFactory { fn with_env_s3(&self, options: &StorageOptions) -> StorageOptions { - let mut options = options.clone(); + let mut options = StorageOptions( + options + .0 + .clone() + .into_iter() + .map(|(k, v)| { + if let Ok(config_key) = AmazonS3ConfigKey::from_str(&k.to_ascii_lowercase()) { + (config_key.as_ref().to_string(), v) + } else { + (k, v) + } + }) + .collect(), + ); + for (os_key, os_value) in std::env::vars_os() { if let (Some(key), Some(value)) = (os_key.to_str(), os_value.to_str()) { if let Ok(config_key) = AmazonS3ConfigKey::from_str(&key.to_ascii_lowercase()) { @@ -64,7 +83,7 @@ impl ObjectStoreFactory for S3ObjectStoreFactory { return Ok((Arc::from(store), prefix)); } - let options = S3StorageOptions::from_map(&options.0); + let options = S3StorageOptions::from_map(&options.0)?; let store = S3StorageBackend::try_new( store.into(), @@ -78,30 +97,38 @@ impl ObjectStoreFactory for S3ObjectStoreFactory { /// Options used to configure the [S3StorageBackend]. /// /// Available options are described in [s3_constants]. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] #[allow(missing_docs)] pub struct S3StorageOptions { - pub endpoint_url: Option, - pub region: Region, - pub profile: Option, - pub aws_access_key_id: Option, - pub aws_secret_access_key: Option, - pub aws_session_token: Option, pub virtual_hosted_style_request: bool, pub locking_provider: Option, - pub assume_role_arn: Option, - pub assume_role_session_name: Option, - pub use_web_identity: bool, pub s3_pool_idle_timeout: Duration, pub sts_pool_idle_timeout: Duration, pub s3_get_internal_server_error_retries: usize, pub allow_unsafe_rename: bool, pub extra_opts: HashMap, + pub sdk_config: SdkConfig, +} + +impl Eq for S3StorageOptions {} +impl PartialEq for S3StorageOptions { + fn eq(&self, other: &Self) -> bool { + self.virtual_hosted_style_request == other.virtual_hosted_style_request + && self.locking_provider == other.locking_provider + && self.s3_pool_idle_timeout == other.s3_pool_idle_timeout + && self.sts_pool_idle_timeout == other.sts_pool_idle_timeout + && self.s3_get_internal_server_error_retries + == other.s3_get_internal_server_error_retries + && self.allow_unsafe_rename == other.allow_unsafe_rename + && self.extra_opts == other.extra_opts + && self.sdk_config.endpoint_url() == other.sdk_config.endpoint_url() + && self.sdk_config.region() == other.sdk_config.region() + } } impl S3StorageOptions { /// Creates an instance of S3StorageOptions from the given HashMap. 
- pub fn from_map(options: &HashMap) -> S3StorageOptions { + pub fn from_map(options: &HashMap) -> DeltaResult { let extra_opts = options .iter() .filter(|(k, _)| !s3_constants::S3_OPTS.contains(&k.as_str())) @@ -118,18 +145,6 @@ impl S3StorageOptions { Self::ensure_env_var(options, s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE); Self::ensure_env_var(options, s3_constants::AWS_ROLE_ARN); Self::ensure_env_var(options, s3_constants::AWS_ROLE_SESSION_NAME); - - let endpoint_url = str_option(options, s3_constants::AWS_ENDPOINT_URL); - let region = if let Some(endpoint_url) = endpoint_url.as_ref() { - Region::Custom { - name: Self::str_or_default(options, s3_constants::AWS_REGION, "custom".to_string()), - endpoint: endpoint_url.to_owned(), - } - } else { - Region::default() - }; - let profile = str_option(options, s3_constants::AWS_PROFILE); - let s3_pool_idle_timeout = Self::u64_or_default(options, s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, 15); let sts_pool_idle_timeout = @@ -150,30 +165,43 @@ impl S3StorageOptions { .map(|val| str_is_truthy(&val)) .unwrap_or(false); - Self { - endpoint_url, - region, - profile, - aws_access_key_id: str_option(options, s3_constants::AWS_ACCESS_KEY_ID), - aws_secret_access_key: str_option(options, s3_constants::AWS_SECRET_ACCESS_KEY), - aws_session_token: str_option(options, s3_constants::AWS_SESSION_TOKEN), + #[cfg(feature = "native-tls")] + let sdk_config = execute_sdk_future( + aws_config::ConfigLoader::default() + .http_client(native::use_native_tls_client( + str_option(options, s3_constants::AWS_ALLOW_HTTP) + .map(|val| str_is_truthy(&val)) + .unwrap_or(false), + )) + .load(), + )?; + #[cfg(feature = "rustls")] + let sdk_config = execute_sdk_future(aws_config::load_from_env())?; + + let sdk_config = + if let Some(endpoint_url) = str_option(options, s3_constants::AWS_ENDPOINT_URL) { + sdk_config.to_builder().endpoint_url(endpoint_url).build() + } else { + sdk_config + }; + Ok(Self { virtual_hosted_style_request, locking_provider: str_option(options, s3_constants::AWS_S3_LOCKING_PROVIDER), - assume_role_arn: str_option(options, s3_constants::AWS_S3_ASSUME_ROLE_ARN), - assume_role_session_name: str_option(options, s3_constants::AWS_S3_ROLE_SESSION_NAME), - use_web_identity: std::env::var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE).is_ok(), s3_pool_idle_timeout: Duration::from_secs(s3_pool_idle_timeout), sts_pool_idle_timeout: Duration::from_secs(sts_pool_idle_timeout), s3_get_internal_server_error_retries, allow_unsafe_rename, extra_opts, - } + sdk_config, + }) } - fn str_or_default(map: &HashMap, key: &str, default: String) -> String { - map.get(key) - .map(|v| v.to_owned()) - .unwrap_or_else(|| std::env::var(key).unwrap_or(default)) + pub fn endpoint_url(&self) -> Option<&str> { + self.sdk_config.endpoint_url() + } + + pub fn region(&self) -> Option<&Region> { + self.sdk_config.region() } fn u64_or_default(map: &HashMap, key: &str, default: u64) -> u64 { @@ -187,15 +215,45 @@ impl S3StorageOptions { std::env::set_var(key, val); } } -} -impl Default for S3StorageOptions { - /// Creates an instance of S3StorageOptions from environment variables. 
- fn default() -> S3StorageOptions { + pub fn try_default() -> DeltaResult { Self::from_map(&HashMap::new()) } } +fn execute_sdk_future + Send + 'static>( + future: F, +) -> DeltaResult { + match tokio::runtime::Handle::try_current() { + Ok(handle) => match handle.runtime_flavor() { + tokio::runtime::RuntimeFlavor::MultiThread => { + Ok(tokio::task::block_in_place(move || handle.block_on(future))) + } + _ => { + let mut cfg: Option = None; + std::thread::scope(|scope| { + scope.spawn(|| { + cfg = Some(handle.block_on(future)); + }); + }); + cfg.ok_or(deltalake_core::DeltaTableError::ObjectStore { + source: ObjectStoreError::Generic { + store: STORE_NAME, + source: Box::new(DynamoDbConfigError::InitializationError), + }, + }) + } + }, + Err(_) => { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("a tokio runtime is required by the AWS sdk"); + Ok(runtime.block_on(future)) + } + } +} + /// An S3 implementation of the [ObjectStore] trait pub struct S3StorageBackend { inner: ObjectStoreRef, @@ -409,201 +467,281 @@ mod tests { use maplit::hashmap; use serial_test::serial; + struct ScopedEnv { + vars: HashMap, + } + + impl ScopedEnv { + pub fn new() -> Self { + let vars = std::env::vars_os().collect(); + Self { vars } + } + + pub fn run(mut f: impl FnMut() -> T) -> T { + let _env_scope = Self::new(); + f() + } + } + + impl Drop for ScopedEnv { + fn drop(&mut self) { + let to_remove: Vec<_> = std::env::vars_os() + .map(|kv| kv.0) + .filter(|k| !self.vars.contains_key(k)) + .collect(); + for k in to_remove { + std::env::remove_var(k); + } + for (key, value) in self.vars.drain() { + std::env::set_var(key, value); + } + } + } + + fn clear_env_of_aws_keys() { + let keys_to_clear = std::env::vars().filter_map(|(k, _v)| { + if AmazonS3ConfigKey::from_str(&k.to_ascii_lowercase()).is_ok() { + Some(k) + } else { + None + } + }); + + for k in keys_to_clear { + std::env::remove_var(k); + } + } + #[test] #[serial] fn storage_options_default_test() { - std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var(s3_constants::AWS_REGION, "us-west-1"); - std::env::set_var(s3_constants::AWS_PROFILE, "default"); - std::env::set_var(s3_constants::AWS_ACCESS_KEY_ID, "default_key_id"); - std::env::set_var(s3_constants::AWS_SECRET_ACCESS_KEY, "default_secret_key"); - std::env::set_var(s3_constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - s3_constants::AWS_S3_ASSUME_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(s3_constants::AWS_S3_ROLE_SESSION_NAME, "session_name"); - std::env::set_var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - std::env::remove_var(s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS); - std::env::remove_var(s3_constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS); - std::env::remove_var(s3_constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES); - - let options = S3StorageOptions::default(); - - assert_eq!( - S3StorageOptions { - endpoint_url: Some("http://localhost".to_string()), - region: Region::Custom { - name: "us-west-1".to_string(), - endpoint: "http://localhost".to_string() + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + + std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var(s3_constants::AWS_REGION, "us-west-1"); + std::env::set_var(s3_constants::AWS_PROFILE, "default"); + std::env::set_var(s3_constants::AWS_ACCESS_KEY_ID, "default_key_id"); + std::env::set_var(s3_constants::AWS_SECRET_ACCESS_KEY, 
"default_secret_key"); + std::env::set_var(s3_constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + s3_constants::AWS_S3_ASSUME_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(s3_constants::AWS_S3_ROLE_SESSION_NAME, "session_name"); + std::env::set_var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + + let options = S3StorageOptions::try_default().unwrap(); + assert_eq!( + S3StorageOptions { + sdk_config: SdkConfig::builder() + .endpoint_url("http://localhost".to_string()) + .region(Region::from_static("us-west-1")) + .build(), + virtual_hosted_style_request: false, + locking_provider: Some("dynamodb".to_string()), + s3_pool_idle_timeout: Duration::from_secs(15), + sts_pool_idle_timeout: Duration::from_secs(10), + s3_get_internal_server_error_retries: 10, + extra_opts: HashMap::new(), + allow_unsafe_rename: false, }, - profile: Some("default".to_string()), - aws_access_key_id: Some("default_key_id".to_string()), - aws_secret_access_key: Some("default_secret_key".to_string()), - aws_session_token: None, - virtual_hosted_style_request: false, - assume_role_arn: Some("arn:aws:iam::123456789012:role/some_role".to_string()), - assume_role_session_name: Some("session_name".to_string()), - use_web_identity: true, - locking_provider: Some("dynamodb".to_string()), - s3_pool_idle_timeout: Duration::from_secs(15), - sts_pool_idle_timeout: Duration::from_secs(10), - s3_get_internal_server_error_retries: 10, - extra_opts: HashMap::new(), - allow_unsafe_rename: false, - }, - options - ); + options + ); + }); } #[test] #[serial] fn storage_options_with_only_region_and_credentials() { - std::env::remove_var(s3_constants::AWS_ENDPOINT_URL); - let options = S3StorageOptions::from_map(&hashmap! { - s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), - s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test".to_string(), - s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(), - }); + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + std::env::remove_var(s3_constants::AWS_ENDPOINT_URL); + let options = S3StorageOptions::from_map(&hashmap! { + s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), + s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test".to_string(), + s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(), + }) + .unwrap(); - assert_eq!( - S3StorageOptions { - endpoint_url: None, - region: Region::default(), - aws_access_key_id: Some("test".to_string()), - aws_secret_access_key: Some("test_secret".to_string()), - ..Default::default() - }, - options - ); + let mut expected = S3StorageOptions::try_default().unwrap(); + expected.sdk_config = SdkConfig::builder() + .region(Region::from_static("eu-west-1")) + .build(); + assert_eq!(expected, options); + }); } #[test] #[serial] fn storage_options_from_map_test() { - let options = S3StorageOptions::from_map(&hashmap! 
{ - s3_constants::AWS_ENDPOINT_URL.to_string() => "http://localhost:1234".to_string(), - s3_constants::AWS_REGION.to_string() => "us-west-2".to_string(), - s3_constants::AWS_PROFILE.to_string() => "default".to_string(), - s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string(), - s3_constants::AWS_S3_LOCKING_PROVIDER.to_string() => "another_locking_provider".to_string(), - s3_constants::AWS_S3_ASSUME_ROLE_ARN.to_string() => "arn:aws:iam::123456789012:role/another_role".to_string(), - s3_constants::AWS_S3_ROLE_SESSION_NAME.to_string() => "another_session_name".to_string(), - s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE.to_string() => "another_token_file".to_string(), - s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS.to_string() => "1".to_string(), - s3_constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS.to_string() => "2".to_string(), - s3_constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES.to_string() => "3".to_string(), - s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test_id".to_string(), - s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(), - }); - - assert_eq!( - S3StorageOptions { - endpoint_url: Some("http://localhost:1234".to_string()), - region: Region::Custom { - name: "us-west-2".to_string(), - endpoint: "http://localhost:1234".to_string() + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + let options = S3StorageOptions::from_map(&hashmap! { + s3_constants::AWS_ENDPOINT_URL.to_string() => "http://localhost:1234".to_string(), + s3_constants::AWS_REGION.to_string() => "us-west-2".to_string(), + s3_constants::AWS_PROFILE.to_string() => "default".to_string(), + s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string(), + s3_constants::AWS_S3_LOCKING_PROVIDER.to_string() => "another_locking_provider".to_string(), + s3_constants::AWS_S3_ASSUME_ROLE_ARN.to_string() => "arn:aws:iam::123456789012:role/another_role".to_string(), + s3_constants::AWS_S3_ROLE_SESSION_NAME.to_string() => "another_session_name".to_string(), + s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE.to_string() => "another_token_file".to_string(), + s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS.to_string() => "1".to_string(), + s3_constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS.to_string() => "2".to_string(), + s3_constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES.to_string() => "3".to_string(), + s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test_id".to_string(), + s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret".to_string(), + }).unwrap(); + + assert_eq!( + S3StorageOptions { + sdk_config: SdkConfig::builder() + .endpoint_url("http://localhost:1234".to_string()) + .region(Region::from_static("us-west-2")) + .build(), + virtual_hosted_style_request: true, + locking_provider: Some("another_locking_provider".to_string()), + s3_pool_idle_timeout: Duration::from_secs(1), + sts_pool_idle_timeout: Duration::from_secs(2), + s3_get_internal_server_error_retries: 3, + extra_opts: hashmap! 
{ + s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string() + }, + allow_unsafe_rename: false, }, - profile: Some("default".to_string()), - aws_access_key_id: Some("test_id".to_string()), - aws_secret_access_key: Some("test_secret".to_string()), - aws_session_token: None, - virtual_hosted_style_request: true, - assume_role_arn: Some("arn:aws:iam::123456789012:role/another_role".to_string()), - assume_role_session_name: Some("another_session_name".to_string()), - use_web_identity: true, - locking_provider: Some("another_locking_provider".to_string()), - s3_pool_idle_timeout: Duration::from_secs(1), - sts_pool_idle_timeout: Duration::from_secs(2), - s3_get_internal_server_error_retries: 3, - extra_opts: hashmap! { - s3_constants::AWS_S3_ADDRESSING_STYLE.to_string() => "virtual".to_string() - }, - allow_unsafe_rename: false, - }, - options - ); + options + ); + }); } #[test] #[serial] fn storage_options_mixed_test() { - std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var(s3_constants::AWS_REGION, "us-west-1"); - std::env::set_var(s3_constants::AWS_PROFILE, "default"); - std::env::set_var(s3_constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); - std::env::set_var(s3_constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); - std::env::set_var(s3_constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - s3_constants::AWS_S3_ASSUME_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(s3_constants::AWS_S3_ROLE_SESSION_NAME, "session_name"); - std::env::set_var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - - std::env::set_var(s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); - std::env::set_var(s3_constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); - std::env::set_var(s3_constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); - let options = S3StorageOptions::from_map(&hashmap! { - s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test_id_mixed".to_string(), - s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret_mixed".to_string(), - s3_constants::AWS_REGION.to_string() => "us-west-2".to_string(), - "AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES".to_string() => "3".to_string(), - }); - - assert_eq!( - S3StorageOptions { - endpoint_url: Some("http://localhost".to_string()), - region: Region::Custom { - name: "us-west-2".to_string(), - endpoint: "http://localhost".to_string() + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var(s3_constants::AWS_REGION, "us-west-1"); + std::env::set_var(s3_constants::AWS_PROFILE, "default"); + std::env::set_var(s3_constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); + std::env::set_var(s3_constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); + std::env::set_var(s3_constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + s3_constants::AWS_S3_ASSUME_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(s3_constants::AWS_S3_ROLE_SESSION_NAME, "session_name"); + std::env::set_var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + + std::env::set_var(s3_constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); + std::env::set_var(s3_constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); + std::env::set_var(s3_constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); + let options = S3StorageOptions::from_map(&hashmap! 
{ + s3_constants::AWS_ACCESS_KEY_ID.to_string() => "test_id_mixed".to_string(), + s3_constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret_mixed".to_string(), + s3_constants::AWS_REGION.to_string() => "us-west-2".to_string(), + "AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES".to_string() => "3".to_string(), + }) + .unwrap(); + + assert_eq!( + S3StorageOptions { + sdk_config: SdkConfig::builder() + .endpoint_url("http://localhost".to_string()) + .region(Region::from_static("us-west-2")) + .build(), + virtual_hosted_style_request: false, + locking_provider: Some("dynamodb".to_string()), + s3_pool_idle_timeout: Duration::from_secs(1), + sts_pool_idle_timeout: Duration::from_secs(2), + s3_get_internal_server_error_retries: 3, + extra_opts: hashmap! {}, + allow_unsafe_rename: false, }, - profile: Some("default".to_string()), - aws_access_key_id: Some("test_id_mixed".to_string()), - aws_secret_access_key: Some("test_secret_mixed".to_string()), - aws_session_token: None, - virtual_hosted_style_request: false, - assume_role_arn: Some("arn:aws:iam::123456789012:role/some_role".to_string()), - assume_role_session_name: Some("session_name".to_string()), - use_web_identity: true, - locking_provider: Some("dynamodb".to_string()), - s3_pool_idle_timeout: Duration::from_secs(1), - sts_pool_idle_timeout: Duration::from_secs(2), - s3_get_internal_server_error_retries: 3, - extra_opts: hashmap! {}, - allow_unsafe_rename: false, - }, - options - ); + options + ); + }); } + #[test] #[serial] fn storage_options_web_identity_test() { - let _options = S3StorageOptions::from_map(&hashmap! { - s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), - s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE.to_string() => "web_identity_token_file".to_string(), - s3_constants::AWS_ROLE_ARN.to_string() => "arn:aws:iam::123456789012:role/web_identity_role".to_string(), - s3_constants::AWS_ROLE_SESSION_NAME.to_string() => "web_identity_session_name".to_string(), + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + let _options = S3StorageOptions::from_map(&hashmap! { + s3_constants::AWS_REGION.to_string() => "eu-west-1".to_string(), + s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE.to_string() => "web_identity_token_file".to_string(), + s3_constants::AWS_ROLE_ARN.to_string() => "arn:aws:iam::123456789012:role/web_identity_role".to_string(), + s3_constants::AWS_ROLE_SESSION_NAME.to_string() => "web_identity_session_name".to_string(), + }).unwrap(); + + assert_eq!( + "eu-west-1", + std::env::var(s3_constants::AWS_REGION).unwrap() + ); + + assert_eq!( + "web_identity_token_file", + std::env::var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE).unwrap() + ); + + assert_eq!( + "arn:aws:iam::123456789012:role/web_identity_role", + std::env::var(s3_constants::AWS_ROLE_ARN).unwrap() + ); + + assert_eq!( + "web_identity_session_name", + std::env::var(s3_constants::AWS_ROLE_SESSION_NAME).unwrap() + ); }); + } - assert_eq!( - "eu-west-1", - std::env::var(s3_constants::AWS_REGION).unwrap() - ); + #[test] + #[serial] + fn when_merging_with_env_unsupplied_options_are_added() { + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + let raw_options = hashmap! 
{}; - assert_eq!( - "web_identity_token_file", - std::env::var(s3_constants::AWS_WEB_IDENTITY_TOKEN_FILE).unwrap() - ); + std::env::set_var(s3_constants::AWS_ACCESS_KEY_ID, "env_key"); + std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "env_key"); + std::env::set_var(s3_constants::AWS_SECRET_ACCESS_KEY, "env_key"); + std::env::set_var(s3_constants::AWS_REGION, "env_key"); - assert_eq!( - "arn:aws:iam::123456789012:role/web_identity_role", - std::env::var(s3_constants::AWS_ROLE_ARN).unwrap() - ); + let combined_options = + S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); - assert_eq!( - "web_identity_session_name", - std::env::var(s3_constants::AWS_ROLE_SESSION_NAME).unwrap() - ); + assert_eq!(combined_options.0.len(), 4); + + for v in combined_options.0.values() { + assert_eq!(v, "env_key"); + } + }); + } + + #[tokio::test] + #[serial] + async fn when_merging_with_env_supplied_options_take_precedence() { + ScopedEnv::run(|| { + clear_env_of_aws_keys(); + let raw_options = hashmap! { + "AWS_ACCESS_KEY_ID".to_string() => "options_key".to_string(), + "AWS_ENDPOINT_URL".to_string() => "options_key".to_string(), + "AWS_SECRET_ACCESS_KEY".to_string() => "options_key".to_string(), + "AWS_REGION".to_string() => "options_key".to_string() + }; + + std::env::set_var("aws_access_key_id", "env_key"); + std::env::set_var("aws_endpoint", "env_key"); + std::env::set_var("aws_secret_access_key", "env_key"); + std::env::set_var("aws_region", "env_key"); + + let combined_options = + S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); + + for v in combined_options.0.values() { + assert_eq!(v, "options_key"); + } + }); } } diff --git a/crates/aws/tests/integration_s3_dynamodb.rs b/crates/aws/tests/integration_s3_dynamodb.rs index 179c46fc5a..880fd47931 100644 --- a/crates/aws/tests/integration_s3_dynamodb.rs +++ b/crates/aws/tests/integration_s3_dynamodb.rs @@ -5,6 +5,8 @@ use std::collections::HashMap; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use aws_config::SdkConfig; +use aws_sdk_dynamodb::types::BillingMode; use deltalake_aws::logstore::{RepairLogEntryResult, S3DynamoDbLogStore}; use deltalake_aws::storage::S3StorageOptions; use deltalake_aws::{CommitEntry, DynamoDbConfig, DynamoDbLockClient}; @@ -31,17 +33,16 @@ lazy_static! { static ref OPTIONS: HashMap = maplit::hashmap! { "allow_http".to_owned() => "true".to_owned(), }; - static ref S3_OPTIONS: S3StorageOptions = S3StorageOptions::from_map(&OPTIONS); + static ref S3_OPTIONS: S3StorageOptions = S3StorageOptions::from_map(&OPTIONS).unwrap(); } fn make_client() -> TestResult { - let options: S3StorageOptions = S3StorageOptions::default(); + let options: S3StorageOptions = S3StorageOptions::try_default().unwrap(); Ok(DynamoDbLockClient::try_new( + &options.sdk_config, None, None, None, - options.region.clone(), - false, )?) 
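As a side note on the contract pinned down by the storage-option tests earlier in this diff: values supplied explicitly in the options map take precedence over values picked up from the environment, and `S3StorageOptions::from_map` is now fallible. A minimal sketch of that behaviour, assuming the paths `deltalake_aws::storage::{s3_constants, S3StorageOptions}` are publicly reachable and reusing the `maplit` crate already used by these tests:

```rust
use std::collections::HashMap;

use deltalake_aws::storage::{s3_constants, S3StorageOptions};
use maplit::hashmap;

fn resolve_options() -> Result<(), Box<dyn std::error::Error>> {
    // A region coming from the process environment...
    std::env::set_var(s3_constants::AWS_REGION, "us-west-1");

    // ...is overridden by the region supplied explicitly in the map,
    // mirroring what `storage_options_mixed_test` asserts above.
    let supplied: HashMap<String, String> = hashmap! {
        s3_constants::AWS_REGION.to_string() => "us-west-2".to_string(),
    };

    // `from_map` now returns a Result instead of panicking internally.
    let resolved = S3StorageOptions::from_map(&supplied)?;
    println!("{resolved:?}");
    Ok(())
}
```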
} @@ -62,13 +63,13 @@ fn client_configs_via_env_variables() -> TestResult<()> { ); let client = make_client()?; let config = client.get_dynamodb_config(); + let options: S3StorageOptions = S3StorageOptions::try_default().unwrap(); assert_eq!( DynamoDbConfig { - billing_mode: deltalake_aws::BillingMode::PayPerRequest, + billing_mode: BillingMode::PayPerRequest, lock_table_name: "some_table".to_owned(), max_elapsed_request_time: Duration::from_secs(64), - use_web_identity: false, - region: config.region.clone(), + sdk_config: options.sdk_config, }, *config, ); @@ -208,7 +209,9 @@ async fn test_concurrent_writers() -> TestResult<()> { for f in futures { map.extend(f.await?); } + validate_lock_table_state(&table, WORKERS * COMMITS).await?; + Ok(()) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index f9df57f794..8a2f2d6817 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-core" -version = "0.17.0" +version = "0.17.1" authors.workspace = true keywords.workspace = true readme.workspace = true @@ -12,9 +12,7 @@ repository.workspace = true rust-version.workspace = true [package.metadata.docs.rs] -# We cannot use all_features because TLS features are mutually exclusive. -# We cannot use hdfs feature because it requires Java to be installed. -features = ["azure", "datafusion", "gcs", "hdfs", "json", "python", "s3", "unity-experimental"] +features = ["datafusion", "json", "unity-experimental"] [dependencies] # arrow diff --git a/crates/core/src/kernel/arrow/mod.rs b/crates/core/src/kernel/arrow/mod.rs index 99809c9c24..7dbac2854d 100644 --- a/crates/core/src/kernel/arrow/mod.rs +++ b/crates/core/src/kernel/arrow/mod.rs @@ -134,8 +134,11 @@ impl TryFrom<&DataType> for ArrowDataType { // timezone. 
Stored as 4 bytes integer representing days since 1970-01-01 Ok(ArrowDataType::Date32) } - PrimitiveType::Timestamp => { - // Issue: https://github.com/delta-io/delta/issues/643 + PrimitiveType::Timestamp => Ok(ArrowDataType::Timestamp( + TimeUnit::Microsecond, + Some("UTC".into()), + )), + PrimitiveType::TimestampNtz => { Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)) } } @@ -217,7 +220,7 @@ impl TryFrom<&ArrowDataType> for DataType { ArrowDataType::Date32 => Ok(DataType::Primitive(PrimitiveType::Date)), ArrowDataType::Date64 => Ok(DataType::Primitive(PrimitiveType::Date)), ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => { - Ok(DataType::Primitive(PrimitiveType::Timestamp)) + Ok(DataType::Primitive(PrimitiveType::TimestampNtz)) } ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.eq_ignore_ascii_case("utc") => @@ -770,6 +773,15 @@ mod tests { #[test] fn test_arrow_from_delta_timestamp_type() { let timestamp_field = DataType::Primitive(PrimitiveType::Timestamp); + assert_eq!( + >::try_from(×tamp_field).unwrap(), + ArrowDataType::Timestamp(TimeUnit::Microsecond, Some("UTC".to_string().into())) + ); + } + + #[test] + fn test_arrow_from_delta_timestampntz_type() { + let timestamp_field = DataType::Primitive(PrimitiveType::TimestampNtz); assert_eq!( >::try_from(×tamp_field).unwrap(), ArrowDataType::Timestamp(TimeUnit::Microsecond, None) @@ -777,11 +789,11 @@ mod tests { } #[test] - fn test_delta_from_arrow_timestamp_type() { + fn test_delta_from_arrow_timestamp_type_no_tz() { let timestamp_field = ArrowDataType::Timestamp(TimeUnit::Microsecond, None); assert_eq!( >::try_from(×tamp_field).unwrap(), - DataType::Primitive(PrimitiveType::Timestamp) + DataType::Primitive(PrimitiveType::TimestampNtz) ); } diff --git a/crates/core/src/kernel/expressions/eval.rs b/crates/core/src/kernel/expressions/eval.rs index 3796542ffc..cb6beea3ad 100644 --- a/crates/core/src/kernel/expressions/eval.rs +++ b/crates/core/src/kernel/expressions/eval.rs @@ -46,7 +46,10 @@ impl Scalar { Double(val) => Arc::new(Float64Array::from_value(*val, num_rows)), String(val) => Arc::new(StringArray::from(vec![val.clone(); num_rows])), Boolean(val) => Arc::new(BooleanArray::from(vec![*val; num_rows])), - Timestamp(val) => Arc::new(TimestampMicrosecondArray::from_value(*val, num_rows)), + Timestamp(val) => { + Arc::new(TimestampMicrosecondArray::from_value(*val, num_rows).with_timezone("UTC")) + } + TimestampNtz(val) => Arc::new(TimestampMicrosecondArray::from_value(*val, num_rows)), Date(val) => Arc::new(Date32Array::from_value(*val, num_rows)), Binary(val) => Arc::new(BinaryArray::from(vec![val.as_slice(); num_rows])), Decimal(val, precision, scale) => Arc::new( @@ -64,6 +67,9 @@ impl Scalar { PrimitiveType::String => Arc::new(StringArray::new_null(num_rows)), PrimitiveType::Boolean => Arc::new(BooleanArray::new_null(num_rows)), PrimitiveType::Timestamp => { + Arc::new(TimestampMicrosecondArray::new_null(num_rows).with_timezone("UTC")) + } + PrimitiveType::TimestampNtz => { Arc::new(TimestampMicrosecondArray::new_null(num_rows)) } PrimitiveType::Date => Arc::new(Date32Array::new_null(num_rows)), diff --git a/crates/core/src/kernel/expressions/scalars.rs b/crates/core/src/kernel/expressions/scalars.rs index 147c9d7633..d29cccb022 100644 --- a/crates/core/src/kernel/expressions/scalars.rs +++ b/crates/core/src/kernel/expressions/scalars.rs @@ -33,6 +33,8 @@ pub enum Scalar { Boolean(bool), /// Microsecond precision timestamp, adjusted to UTC. 
Timestamp(i64), + /// Microsecond precision timestamp, with no timezone. + TimestampNtz(i64), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date(i32), /// Binary data @@ -58,6 +60,7 @@ impl Scalar { Self::String(_) => DataType::Primitive(PrimitiveType::String), Self::Boolean(_) => DataType::Primitive(PrimitiveType::Boolean), Self::Timestamp(_) => DataType::Primitive(PrimitiveType::Timestamp), + Self::TimestampNtz(_) => DataType::Primitive(PrimitiveType::TimestampNtz), Self::Date(_) => DataType::Primitive(PrimitiveType::Date), Self::Binary(_) => DataType::Primitive(PrimitiveType::Binary), Self::Decimal(_, precision, scale) => DataType::decimal(*precision, *scale), @@ -88,7 +91,7 @@ impl Scalar { "false".to_string() } } - Self::Timestamp(ts) => { + Self::TimestampNtz(ts) | Self::Timestamp(ts) => { let ts = Utc.timestamp_micros(*ts).single().unwrap(); ts.format("%Y-%m-%d %H:%M:%S%.6f").to_string() } @@ -222,10 +225,17 @@ impl Scalar { .downcast_ref::() .map(|v| Self::Date(v.value(index))), // TODO handle timezones when implementing timestamp ntz feature. - Timestamp(TimeUnit::Microsecond, None) => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Timestamp(v.value(index))), + Timestamp(TimeUnit::Microsecond, tz) => match tz { + None => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Timestamp(v.value(index))), + Some(tz_str) if tz_str.as_ref() == "UTC" => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Timestamp(v.clone().with_timezone("UTC").value(index))), + _ => None, + }, Struct(fields) => { let struct_fields = fields .iter() @@ -283,6 +293,7 @@ impl PartialOrd for Scalar { (String(a), String(b)) => a.partial_cmp(b), (Boolean(a), Boolean(b)) => a.partial_cmp(b), (Timestamp(a), Timestamp(b)) => a.partial_cmp(b), + (TimestampNtz(a), TimestampNtz(b)) => a.partial_cmp(b), (Date(a), Date(b)) => a.partial_cmp(b), (Binary(a), Binary(b)) => a.partial_cmp(b), (Decimal(a, _, _), Decimal(b, _, _)) => a.partial_cmp(b), @@ -308,6 +319,7 @@ impl Display for Scalar { Self::String(s) => write!(f, "'{}'", s), Self::Boolean(b) => write!(f, "{}", b), Self::Timestamp(ts) => write!(f, "{}", ts), + Self::TimestampNtz(ts) => write!(f, "{}", ts), Self::Date(d) => write!(f, "{}", d), Self::Binary(b) => write!(f, "{:?}", b), Self::Decimal(value, _, scale) => match scale.cmp(&0) { @@ -433,6 +445,16 @@ impl PrimitiveType { .ok_or(self.parse_error(raw))?; Ok(Scalar::Timestamp(micros)) } + TimestampNtz => { + let timestamp = NaiveDateTime::parse_from_str(raw, "%Y-%m-%d %H:%M:%S%.f") + .map_err(|_| self.parse_error(raw))?; + let timestamp = Utc.from_utc_datetime(×tamp); + let micros = timestamp + .signed_duration_since(*UNIX_EPOCH) + .num_microseconds() + .ok_or(self.parse_error(raw))?; + Ok(Scalar::TimestampNtz(micros)) + } Binary => { let bytes = parse_escaped_binary_string(raw).map_err(|_| self.parse_error(raw))?; Ok(Scalar::Binary(bytes)) diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index 28eaa89cc4..f389102e52 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -175,7 +175,7 @@ pub enum ReaderFeatures { /// Deletion vectors for merge, update, delete DeletionVectors, /// timestamps without timezone support - #[serde(alias = "timestampNtz")] + #[serde(rename = "timestampNtz")] TimestampWithoutTimezone, /// version 2 of checkpointing V2Checkpoint, @@ -189,7 +189,9 @@ impl From<&parquet::record::Field> for ReaderFeatures { match value { parquet::record::Field::Str(feature) => 
match feature.as_str() { "columnMapping" => ReaderFeatures::ColumnMapping, - "deletionVectors" => ReaderFeatures::DeletionVectors, + "deletionVectors" | "delta.enableDeletionVectors" => { + ReaderFeatures::DeletionVectors + } "timestampNtz" => ReaderFeatures::TimestampWithoutTimezone, "v2Checkpoint" => ReaderFeatures::V2Checkpoint, f => ReaderFeatures::Other(f.to_string()), @@ -259,7 +261,7 @@ pub enum WriterFeatures { /// Row tracking on tables RowTracking, /// timestamps without timezone support - #[serde(alias = "timestampNtz")] + #[serde(rename = "timestampNtz")] TimestampWithoutTimezone, /// domain specific metadata DomainMetadata, @@ -281,15 +283,15 @@ impl From for WriterFeatures { impl From<&str> for WriterFeatures { fn from(value: &str) -> Self { match value { - "appendOnly" => WriterFeatures::AppendOnly, - "invariants" => WriterFeatures::Invariants, - "checkConstraints" => WriterFeatures::CheckConstraints, - "changeDataFeed" => WriterFeatures::ChangeDataFeed, + "appendOnly" | "delta.appendOnly" => WriterFeatures::AppendOnly, + "invariants" | "delta.invariants" => WriterFeatures::Invariants, + "checkConstraints" | "delta.checkConstraints" => WriterFeatures::CheckConstraints, + "changeDataFeed" | "delta.enableChangeDataFeed" => WriterFeatures::ChangeDataFeed, "generatedColumns" => WriterFeatures::GeneratedColumns, "columnMapping" => WriterFeatures::ColumnMapping, "identityColumns" => WriterFeatures::IdentityColumns, - "deletionVectors" => WriterFeatures::DeletionVectors, - "rowTracking" => WriterFeatures::RowTracking, + "deletionVectors" | "delta.enableDeletionVectors" => WriterFeatures::DeletionVectors, + "rowTracking" | "delta.enableRowTracking" => WriterFeatures::RowTracking, "timestampNtz" => WriterFeatures::TimestampWithoutTimezone, "domainMetadata" => WriterFeatures::DomainMetadata, "v2Checkpoint" => WriterFeatures::V2Checkpoint, diff --git a/crates/core/src/kernel/models/schema.rs b/crates/core/src/kernel/models/schema.rs index 87532d0b66..84e5967f12 100644 --- a/crates/core/src/kernel/models/schema.rs +++ b/crates/core/src/kernel/models/schema.rs @@ -134,6 +134,8 @@ pub struct StructField { impl Hash for StructField { fn hash(&self, state: &mut H) { self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); } } @@ -215,7 +217,7 @@ impl StructField { /// A struct is used to represent both the top-level schema of the table /// as well as struct columns that contain nested columns. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] pub struct StructType { #[serde(rename = "type")] /// The type of this struct @@ -379,7 +381,7 @@ impl<'a> IntoIterator for &'a StructType { } } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] #[serde(rename_all = "camelCase")] /// An array stores a variable length collection of items of some type. 
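It is worth spelling out what the widened match arms above buy: the table-property spellings (`delta.appendOnly`, `delta.enableDeletionVectors`, and so on) now resolve to the same writer features as the canonical feature names. A small sketch, relying only on the `From<&str>` impl shown above and the existing `Other` catch-all variant:

```rust
use deltalake_core::kernel::WriterFeatures;

fn main() {
    // Canonical feature name and table-property spelling map to the same variant.
    assert!(matches!(
        WriterFeatures::from("appendOnly"),
        WriterFeatures::AppendOnly
    ));
    assert!(matches!(
        WriterFeatures::from("delta.appendOnly"),
        WriterFeatures::AppendOnly
    ));

    // Unrecognised strings still fall back to the catch-all variant.
    assert!(matches!(
        WriterFeatures::from("somethingElse"),
        WriterFeatures::Other(_)
    ));
}
```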
pub struct ArrayType { @@ -415,7 +417,7 @@ impl ArrayType { } } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] #[serde(rename_all = "camelCase")] /// A map stores an arbitrary length collection of key-value pairs pub struct MapType { @@ -465,8 +467,8 @@ fn default_true() -> bool { true } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)] -#[serde(rename_all = "camelCase")] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] +#[serde(rename_all = "snake_case")] /// Primitive types supported by Delta pub enum PrimitiveType { /// UTF-8 encoded string of characters @@ -491,7 +493,9 @@ pub enum PrimitiveType { Date, /// Microsecond precision timestamp, adjusted to UTC. Timestamp, - // TODO: timestamp without timezone + /// Micrsoecond precision timestamp with no timezone + #[serde(alias = "timestampNtz")] + TimestampNtz, #[serde( serialize_with = "serialize_decimal", deserialize_with = "deserialize_decimal", @@ -552,6 +556,7 @@ impl Display for PrimitiveType { PrimitiveType::Binary => write!(f, "binary"), PrimitiveType::Date => write!(f, "date"), PrimitiveType::Timestamp => write!(f, "timestamp"), + PrimitiveType::TimestampNtz => write!(f, "timestampNtz"), PrimitiveType::Decimal(precision, scale) => { write!(f, "decimal({},{})", precision, scale) } @@ -559,7 +564,7 @@ impl Display for PrimitiveType { } } -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Eq, Hash)] #[serde(untagged, rename_all = "camelCase")] /// Top level delta tdatatypes pub enum DataType { @@ -606,6 +611,7 @@ impl DataType { pub const BINARY: Self = DataType::Primitive(PrimitiveType::Binary); pub const DATE: Self = DataType::Primitive(PrimitiveType::Date); pub const TIMESTAMP: Self = DataType::Primitive(PrimitiveType::Timestamp); + pub const TIMESTAMPNTZ: Self = DataType::Primitive(PrimitiveType::TimestampNtz); pub fn decimal(precision: u8, scale: i8) -> Self { DataType::Primitive(PrimitiveType::Decimal(precision, scale)) @@ -641,6 +647,7 @@ mod tests { use super::*; use serde_json; use serde_json::json; + use std::collections::hash_map::DefaultHasher; #[test] fn test_serde_data_types() { @@ -866,4 +873,65 @@ mod tests { let buf = r#"{"type":"struct","fields":[{"name":"ID_D_DATE","type":"long","nullable":true,"metadata":{"delta.identity.start":1,"delta.identity.step":1,"delta.identity.allowExplicitInsert":false}},{"name":"TXT_DateKey","type":"string","nullable":true,"metadata":{}}]}"#; let _schema: StructType = serde_json::from_str(buf).expect("Failed to load"); } + + fn get_hash(field: &StructField) -> u64 { + let mut hasher = DefaultHasher::new(); + field.hash(&mut hasher); + hasher.finish() + } + + #[test] + fn test_hash_struct_field() { + // different names should result in different hashes + let field_1 = StructField::new( + "field_name_1", + DataType::Primitive(PrimitiveType::Decimal(4, 4)), + true, + ); + let field_2 = StructField::new( + "field_name_2", + DataType::Primitive(PrimitiveType::Decimal(4, 4)), + true, + ); + assert_ne!(get_hash(&field_1), get_hash(&field_2)); + + // different types should result in different hashes + let field_int = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::Integer), + true, + ); + let field_string = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::String), + true, + ); + assert_ne!(get_hash(&field_int), get_hash(&field_string)); + + // different 
nullability should result in different hashes + let field_true = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::Binary), + true, + ); + let field_false = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::Binary), + false, + ); + assert_ne!(get_hash(&field_true), get_hash(&field_false)); + + // case where hashes are the same + let field_1 = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::Timestamp), + true, + ); + let field_2 = StructField::new( + "field_name", + DataType::Primitive(PrimitiveType::Timestamp), + true, + ); + assert_eq!(get_hash(&field_1), get_hash(&field_2)); + } } diff --git a/crates/core/src/kernel/snapshot/log_segment.rs b/crates/core/src/kernel/snapshot/log_segment.rs index 6ad1690db1..8a88c6a600 100644 --- a/crates/core/src/kernel/snapshot/log_segment.rs +++ b/crates/core/src/kernel/snapshot/log_segment.rs @@ -30,7 +30,7 @@ pub type CommitData = (Vec, DeltaOperation, Option Result { + if let Dictionary(_, value_type) = right.data_type() { + if value_type.equals_datatype(left.data_type()) { + return Ok(left.clone()); + } + } + if let Dictionary(_, value_type) = left.data_type() { + if value_type.equals_datatype(right.data_type()) { + return Ok(right.clone()); + } + } + let mut new_field = left.clone(); + new_field.try_merge(right)?; + Ok(new_field) +} + +pub(crate) fn merge_schema( + left: ArrowSchema, + right: ArrowSchema, +) -> Result { + let mut errors = Vec::with_capacity(left.fields().len()); + let merged_fields: Result, ArrowError> = left + .fields() + .iter() + .map(|field| { + let right_field = right.field_with_name(field.name()); + if let Ok(right_field) = right_field { + let field_or_not = merge_field(field.as_ref(), right_field); + match field_or_not { + Err(e) => { + errors.push(e.to_string()); + Err(e) + } + Ok(f) => Ok(f), + } + } else { + Ok(field.as_ref().clone()) + } + }) + .collect(); + match merged_fields { + Ok(mut fields) => { + for field in right.fields() { + if !left.field_with_name(field.name()).is_ok() { + fields.push(field.as_ref().clone()); + } + } + + Ok(ArrowSchema::new(fields)) + } + Err(e) => { + errors.push(e.to_string()); + Err(ArrowError::SchemaError(errors.join("\n"))) + } + } +} + fn cast_struct( struct_array: &StructArray, fields: &Fields, cast_options: &CastOptions, + add_missing: bool, ) -> Result>, arrow_schema::ArrowError> { fields .iter() .map(|field| { - let col = struct_array.column_by_name(field.name()).unwrap(); - if let (DataType::Struct(_), DataType::Struct(child_fields)) = - (col.data_type(), field.data_type()) - { - let child_struct = StructArray::from(col.into_data()); - let s = cast_struct(&child_struct, child_fields, cast_options)?; - Ok(Arc::new(StructArray::new( - child_fields.clone(), - s, - child_struct.nulls().map(ToOwned::to_owned), - )) as ArrayRef) - } else if is_cast_required(col.data_type(), field.data_type()) { - cast_with_options(col, field.data_type(), cast_options) - } else { - Ok(col.clone()) + let col_or_not = struct_array.column_by_name(field.name()); + match col_or_not { + None => match add_missing { + true => Ok(new_null_array(field.data_type(), struct_array.len())), + false => Err(arrow_schema::ArrowError::SchemaError(format!( + "Could not find column {0}", + field.name() + ))), + }, + Some(col) => { + if let (DataType::Struct(_), DataType::Struct(child_fields)) = + (col.data_type(), field.data_type()) + { + let child_struct = StructArray::from(col.into_data()); + let s = + cast_struct(&child_struct, child_fields, cast_options, 
add_missing)?; + Ok(Arc::new(StructArray::new( + child_fields.clone(), + s, + child_struct.nulls().map(ToOwned::to_owned), + )) as ArrayRef) + } else if is_cast_required(col.data_type(), field.data_type()) { + cast_with_options(col, field.data_type(), cast_options) + } else { + Ok(col.clone()) + } + } } }) .collect::, _>>() @@ -51,6 +124,7 @@ pub fn cast_record_batch( batch: &RecordBatch, target_schema: ArrowSchemaRef, safe: bool, + add_missing: bool, ) -> DeltaResult { let cast_options = CastOptions { safe, @@ -62,8 +136,7 @@ pub fn cast_record_batch( batch.columns().to_owned(), None, ); - - let columns = cast_struct(&s, target_schema.fields(), &cast_options)?; + let columns = cast_struct(&s, target_schema.fields(), &cast_options, add_missing)?; Ok(RecordBatch::try_new(target_schema, columns)?) } @@ -93,7 +166,7 @@ mod tests { )]); let target_schema = Arc::new(Schema::new(fields)) as SchemaRef; - let result = cast_record_batch(&record_batch, target_schema, false); + let result = cast_record_batch(&record_batch, target_schema, false, false); let schema = result.unwrap().schema(); let field = schema.column_with_name("list_column").unwrap().1; diff --git a/crates/core/src/operations/constraints.rs b/crates/core/src/operations/constraints.rs index 9bf5f2d22c..aaa0d62d3a 100644 --- a/crates/core/src/operations/constraints.rs +++ b/crates/core/src/operations/constraints.rs @@ -17,7 +17,7 @@ use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::{ register_store, DeltaDataChecker, DeltaScanBuilder, DeltaSessionContext, }; -use crate::kernel::{CommitInfo, IsolationLevel, Protocol}; +use crate::kernel::{CommitInfo, IsolationLevel, Protocol, WriterFeatures}; use crate::logstore::LogStoreRef; use crate::operations::datafusion_utils::Expression; use crate::operations::transaction::commit; @@ -177,7 +177,17 @@ impl std::future::IntoFuture for ConstraintBuilder { 3 }, reader_features: old_protocol.reader_features.clone(), - writer_features: old_protocol.writer_features.clone(), + writer_features: if old_protocol.min_writer_version < 7 { + old_protocol.writer_features.clone() + } else { + let current_features = old_protocol.writer_features.clone(); + if let Some(mut features) = current_features { + features.insert(WriterFeatures::CheckConstraints); + Some(features) + } else { + current_features + } + }, }; let operational_parameters = HashMap::from_iter([ @@ -190,10 +200,7 @@ impl std::future::IntoFuture for ConstraintBuilder { expr: expr_str.clone(), }; - let app_metadata = match this.app_metadata { - Some(metadata) => metadata, - None => HashMap::default(), - }; + let app_metadata = this.app_metadata.unwrap_or_default(); let commit_info = CommitInfo { timestamp: Some(Utc::now().timestamp_millis()), diff --git a/crates/core/src/operations/convert_to_delta.rs b/crates/core/src/operations/convert_to_delta.rs index 1ed4e1cee6..351a596062 100644 --- a/crates/core/src/operations/convert_to_delta.rs +++ b/crates/core/src/operations/convert_to_delta.rs @@ -226,7 +226,7 @@ impl ConvertToDeltaBuilder { } /// Consume self into CreateBuilder with corresponding add actions, schemas and operation meta - async fn into_create_builder(mut self) -> Result { + async fn into_create_builder(self) -> Result { // Use the specified log store. If a log store is not provided, create a new store from the specified path. 
// Return an error if neither log store nor path is provided let log_store = if let Some(log_store) = self.log_store { @@ -270,6 +270,13 @@ impl ConvertToDeltaBuilder { // Iterate over the parquet files. Parse partition columns, generate add actions and collect parquet file schemas let mut arrow_schemas = Vec::new(); let mut actions = Vec::new(); + // partition columns that were defined by caller and are expected to apply on this table + let mut expected_partitions: HashMap = self + .partition_schema + .clone() + .into_iter() + .map(|field| (field.name.clone(), field)) + .collect(); // A HashSet of all unique partition columns in a Parquet table let mut partition_columns = HashSet::new(); // A vector of StructField of all unique partition columns in a Parquet table @@ -290,7 +297,7 @@ impl ConvertToDeltaBuilder { .ok_or(Error::MissingPartitionSchema)?; if partition_columns.insert(key.to_string()) { - if let Some(schema) = self.partition_schema.take(key) { + if let Some(schema) = expected_partitions.remove(key) { partition_schema_fields.insert(key.to_string(), schema); } else { // Return an error if the schema of a partition column is not provided by user @@ -360,7 +367,7 @@ impl ConvertToDeltaBuilder { arrow_schemas.push(arrow_schema); } - if !self.partition_schema.is_empty() { + if !expected_partitions.is_empty() { // Partition column provided by the user does not exist in the parquet files return Err(Error::PartitionColumnNotExist(self.partition_schema)); } diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index bbf11e3705..e754674ccb 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -1,7 +1,7 @@ //! Command for creating a new delta table // https://github.com/delta-io/delta/blob/master/core/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use futures::future::BoxFuture; @@ -9,7 +9,9 @@ use serde_json::Value; use super::transaction::{commit, PROTOCOL}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, DataType, Metadata, Protocol, StructField, StructType}; +use crate::kernel::{ + Action, DataType, Metadata, Protocol, ReaderFeatures, StructField, StructType, WriterFeatures, +}; use crate::logstore::{LogStore, LogStoreRef}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::table::builder::ensure_table_uri; @@ -233,8 +235,45 @@ impl CreateBuilder { ) }; + let contains_timestampntz = &self + .columns + .iter() + .any(|f| f.data_type() == &DataType::TIMESTAMPNTZ); + // TODO configure more permissive versions based on configuration. Also how should this ideally be handled? // We set the lowest protocol we can, and if subsequent writes use newer features we update metadata? 
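The branch introduced just below derives the protocol from the schema: when any column is `timestampNtz`, table creation advertises reader version 3 / writer version 7 together with the timestampNtz table features, otherwise the defaults are kept. A sketch of what that looks like from the caller's side (the `DeltaOps::create`/`with_columns` builder calls are assumed from the existing API; the column types come straight from this diff):

```rust
use deltalake_core::kernel::{DataType, PrimitiveType, StructField};
use deltalake_core::{DeltaOps, DeltaResult};

async fn create_with_timestamp_ntz() -> DeltaResult<()> {
    let table = DeltaOps::new_in_memory()
        .create()
        .with_columns(vec![
            StructField::new("id", DataType::Primitive(PrimitiveType::Long), false),
            StructField::new(
                "event_time",
                DataType::Primitive(PrimitiveType::TimestampNtz),
                true,
            ),
        ])
        .await?;

    // Because `event_time` is timestampNtz, the commit written here carries
    // minReaderVersion 3 / minWriterVersion 7 plus the timestampNtz reader
    // and writer features, rather than the default protocol versions.
    assert_eq!(table.version(), 0);
    Ok(())
}
```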
+ + let (min_reader_version, min_writer_version, writer_features, reader_features) = + if *contains_timestampntz { + let mut converted_writer_features = self + .configuration + .keys() + .map(|key| key.clone().into()) + .filter(|v| !matches!(v, WriterFeatures::Other(_))) + .collect::>(); + + let mut converted_reader_features = self + .configuration + .keys() + .map(|key| key.clone().into()) + .filter(|v| !matches!(v, ReaderFeatures::Other(_))) + .collect::>(); + converted_writer_features.insert(WriterFeatures::TimestampWithoutTimezone); + converted_reader_features.insert(ReaderFeatures::TimestampWithoutTimezone); + ( + 3, + 7, + Some(converted_writer_features), + Some(converted_reader_features), + ) + } else { + ( + PROTOCOL.default_reader_version(), + PROTOCOL.default_writer_version(), + None, + None, + ) + }; let protocol = self .actions .iter() @@ -244,10 +283,10 @@ impl CreateBuilder { _ => unreachable!(), }) .unwrap_or_else(|| Protocol { - min_reader_version: PROTOCOL.default_reader_version(), - min_writer_version: PROTOCOL.default_writer_version(), - writer_features: None, - reader_features: None, + min_reader_version, + min_writer_version, + writer_features, + reader_features, }); let mut metadata = Metadata::try_new( diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 2e3e99bde2..072f4cd875 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -17,6 +17,7 @@ //! .await?; //! ```` +use core::panic; use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; @@ -167,9 +168,15 @@ async fn excute_non_empty_expr( None, writer_properties, false, - false, + None, ) - .await?; + .await? + .into_iter() + .map(|a| match a { + Action::Add(a) => a, + _ => panic!("Expected Add action"), + }) + .collect::>(); let read_records = scan.parquet_scan.metrics().and_then(|m| m.output_rows()); let filter_records = filter.metrics().and_then(|m| m.output_rows()); @@ -194,7 +201,7 @@ async fn execute( let scan_start = Instant::now(); let candidates = find_files(snapshot, log_store.clone(), &state, predicate.clone()).await?; - metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_micros(); + metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis(); let predicate = predicate.unwrap_or(Expr::Literal(ScalarValue::Boolean(Some(true)))); @@ -242,7 +249,7 @@ async fn execute( })) } - metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_micros(); + metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis(); let mut app_metadata = match app_metadata { Some(meta) => meta, diff --git a/crates/core/src/operations/drop_constraints.rs b/crates/core/src/operations/drop_constraints.rs new file mode 100644 index 0000000000..417f73d8e8 --- /dev/null +++ b/crates/core/src/operations/drop_constraints.rs @@ -0,0 +1,205 @@ +//! 
Drop a constraint from a table + +use std::collections::HashMap; + +use chrono::Utc; +use futures::future::BoxFuture; +use serde_json::json; + +use crate::kernel::{Action, CommitInfo, IsolationLevel}; +use crate::logstore::LogStoreRef; +use crate::operations::transaction::commit; +use crate::protocol::DeltaOperation; +use crate::table::state::DeltaTableState; +use crate::DeltaTable; +use crate::{DeltaResult, DeltaTableError}; + +/// Remove constraints from the table +pub struct DropConstraintBuilder { + /// A snapshot of the table's state + snapshot: DeltaTableState, + /// Name of the constraint + name: Option, + /// Raise if constraint doesn't exist + raise_if_not_exists: bool, + /// Delta object store for handling data files + log_store: LogStoreRef, + /// Additional metadata to be added to commit + app_metadata: Option>, +} + +impl DropConstraintBuilder { + /// Create a new builder + pub fn new(log_store: LogStoreRef, snapshot: DeltaTableState) -> Self { + Self { + name: None, + raise_if_not_exists: true, + snapshot, + log_store, + app_metadata: None, + } + } + + /// Specify the constraint to be removed + pub fn with_constraint>(mut self, name: S) -> Self { + self.name = Some(name.into()); + self + } + + /// Specify if you want to raise if the constraint does not exist + pub fn with_raise_if_not_exists(mut self, raise: bool) -> Self { + self.raise_if_not_exists = raise; + self + } + + /// Additional metadata to be added to commit info + pub fn with_metadata( + mut self, + metadata: impl IntoIterator, + ) -> Self { + self.app_metadata = Some(HashMap::from_iter(metadata)); + self + } +} + +impl std::future::IntoFuture for DropConstraintBuilder { + type Output = DeltaResult; + + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + let mut this = self; + + Box::pin(async move { + let name = this + .name + .ok_or(DeltaTableError::Generic("No name provided".to_string()))?; + + let mut metadata = this.snapshot.metadata().clone(); + let configuration_key = format!("delta.constraints.{}", name); + + if metadata.configuration.remove(&configuration_key).is_none() { + if this.raise_if_not_exists { + return Err(DeltaTableError::Generic(format!( + "Constraint with name: {} doesn't exists", + name + ))); + } + return Ok(DeltaTable::new_with_state(this.log_store, this.snapshot)); + } + let operational_parameters = HashMap::from_iter([("name".to_string(), json!(&name))]); + + let operations = DeltaOperation::DropConstraint { name: name.clone() }; + + let app_metadata = this.app_metadata.unwrap_or_default(); + + let commit_info = CommitInfo { + timestamp: Some(Utc::now().timestamp_millis()), + operation: Some(operations.name().to_string()), + operation_parameters: Some(operational_parameters), + read_version: Some(this.snapshot.version()), + isolation_level: Some(IsolationLevel::Serializable), + is_blind_append: Some(false), + info: app_metadata, + ..Default::default() + }; + + let actions = vec![Action::CommitInfo(commit_info), Action::Metadata(metadata)]; + + let version = commit( + this.log_store.as_ref(), + &actions, + operations.clone(), + Some(&this.snapshot), + None, + ) + .await?; + + this.snapshot.merge(actions, &operations, version)?; + Ok(DeltaTable::new_with_state(this.log_store, this.snapshot)) + }) + } +} + +#[cfg(feature = "datafusion")] +#[cfg(test)] +mod tests { + use crate::writer::test_utils::{create_bare_table, get_record_batch}; + use crate::{DeltaOps, DeltaResult, DeltaTable}; + + async fn get_constraint_op_params(table: &mut 
DeltaTable) -> String { + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[0]; + + last_commit + .operation_parameters + .as_ref() + .unwrap() + .get("name") + .unwrap() + .as_str() + .unwrap() + .to_owned() + } + + #[tokio::test] + async fn drop_valid_constraint() -> DeltaResult<()> { + let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + let table = DeltaOps(write); + + let table = table + .add_constraint() + .with_constraint("id", "value < 1000") + .await?; + + let mut table = DeltaOps(table) + .drop_constraints() + .with_constraint("id") + .await?; + + let expected_name = "id"; + assert_eq!(get_constraint_op_params(&mut table).await, expected_name); + assert_eq!(table.metadata().unwrap().configuration.get("id"), None); + Ok(()) + } + + #[tokio::test] + async fn drop_invalid_constraint_not_existing() -> DeltaResult<()> { + let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + + let table = DeltaOps(write) + .drop_constraints() + .with_constraint("not_existing") + .await; + assert!(table.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn drop_invalid_constraint_ignore() -> DeltaResult<()> { + let batch = get_record_batch(None, false); + let write = DeltaOps(create_bare_table()) + .write(vec![batch.clone()]) + .await?; + + let version = write.version(); + + let table = DeltaOps(write) + .drop_constraints() + .with_constraint("not_existing") + .with_raise_if_not_exists(false) + .await?; + + let version_after = table.version(); + + assert_eq!(version, version_after); + Ok(()) + } +} diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 636ed6f66d..96c06fe548 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -1184,7 +1184,7 @@ async fn execute( let projection = join.with_column(OPERATION_COLUMN, case)?; - let mut new_columns = projection; + let mut new_columns = vec![]; let mut write_projection = Vec::new(); for delta_field in snapshot.schema().fields() { @@ -1223,7 +1223,7 @@ async fn execute( Expr::Column(Column::from_qualified_name_ignore_case(name.clone())) .alias(delta_field.name()), ); - new_columns = new_columns.with_column(&name, case)?; + new_columns.push((name, case)); } let mut insert_when = Vec::with_capacity(ops.len()); @@ -1299,18 +1299,40 @@ async fn execute( .end() } - new_columns = new_columns.with_column(DELETE_COLUMN, build_case(delete_when, delete_then)?)?; - new_columns = - new_columns.with_column(TARGET_INSERT_COLUMN, build_case(insert_when, insert_then)?)?; - new_columns = - new_columns.with_column(TARGET_UPDATE_COLUMN, build_case(update_when, update_then)?)?; - new_columns = new_columns.with_column( - TARGET_DELETE_COLUMN, + new_columns.push(( + DELETE_COLUMN.to_owned(), + build_case(delete_when, delete_then)?, + )); + new_columns.push(( + TARGET_INSERT_COLUMN.to_owned(), + build_case(insert_when, insert_then)?, + )); + new_columns.push(( + TARGET_UPDATE_COLUMN.to_owned(), + build_case(update_when, update_then)?, + )); + new_columns.push(( + TARGET_DELETE_COLUMN.to_owned(), build_case(target_delete_when, target_delete_then)?, - )?; - new_columns = new_columns.with_column(TARGET_COPY_COLUMN, build_case(copy_when, copy_then)?)?; - - let new_columns = new_columns.into_unoptimized_plan(); + )); + new_columns.push(( + TARGET_COPY_COLUMN.to_owned(), + build_case(copy_when, 
copy_then)?, + )); + + let new_columns = { + let plan = projection.into_unoptimized_plan(); + let mut fields: Vec = plan + .schema() + .fields() + .iter() + .map(|f| col(f.qualified_column())) + .collect(); + + fields.extend(new_columns.into_iter().map(|(name, ex)| ex.alias(name))); + + LogicalPlanBuilder::from(plan).project(fields)?.build()? + }; let distrbute_expr = col(file_column.as_str()); @@ -1357,13 +1379,13 @@ async fn execute( None, writer_properties, safe_cast, - false, + None, ) .await?; metrics.rewrite_time_ms = Instant::now().duration_since(rewrite_start).as_millis() as u64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_target_files_added = actions.len(); let survivors = barrier @@ -1552,6 +1574,7 @@ mod tests { table } + // TODO(ion): property keys are not passed through or translated as table features.. fix this as well #[tokio::test] async fn test_merge_when_delta_table_is_append_only() { let schema = get_arrow_schema(&None); diff --git a/crates/core/src/operations/mod.rs b/crates/core/src/operations/mod.rs index 2271f36641..666b2dc66a 100644 --- a/crates/core/src/operations/mod.rs +++ b/crates/core/src/operations/mod.rs @@ -18,6 +18,7 @@ use std::collections::HashMap; pub mod cast; pub mod convert_to_delta; pub mod create; +pub mod drop_constraints; pub mod filesystem_check; pub mod optimize; pub mod restore; @@ -27,7 +28,8 @@ pub mod vacuum; #[cfg(feature = "datafusion")] use self::{ constraints::ConstraintBuilder, datafusion_utils::Expression, delete::DeleteBuilder, - load::LoadBuilder, merge::MergeBuilder, update::UpdateBuilder, write::WriteBuilder, + drop_constraints::DropConstraintBuilder, load::LoadBuilder, merge::MergeBuilder, + update::UpdateBuilder, write::WriteBuilder, }; #[cfg(feature = "datafusion")] pub use ::datafusion::physical_plan::common::collect as collect_sendable_stream; @@ -199,6 +201,13 @@ impl DeltaOps { pub fn add_constraint(self) -> ConstraintBuilder { ConstraintBuilder::new(self.0.log_store, self.0.state.unwrap()) } + + /// Drops constraints from a table + #[cfg(feature = "datafusion")] + #[must_use] + pub fn drop_constraints(self) -> DropConstraintBuilder { + DropConstraintBuilder::new(self.0.log_store, self.0.state.unwrap()) + } } impl From for DeltaOps { diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 990997399e..90334e6de1 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -457,8 +457,12 @@ impl MergePlan { while let Some(maybe_batch) = read_stream.next().await { let mut batch = maybe_batch?; - batch = - super::cast::cast_record_batch(&batch, task_parameters.file_schema.clone(), false)?; + batch = super::cast::cast_record_batch( + &batch, + task_parameters.file_schema.clone(), + false, + false, + )?; partial_metrics.num_batches += 1; writer.write(&batch).await.map_err(DeltaTableError::from)?; } diff --git a/crates/core/src/operations/transaction/mod.rs b/crates/core/src/operations/transaction/mod.rs index 63d1789e0a..12d548b85f 100644 --- a/crates/core/src/operations/transaction/mod.rs +++ b/crates/core/src/operations/transaction/mod.rs @@ -13,6 +13,7 @@ use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, CommitInfo, ReaderFeatures, WriterFeatures}; use crate::logstore::LogStore; use crate::protocol::DeltaOperation; +use crate::storage::ObjectStoreRetryExt; use crate::table::state::DeltaTableState; pub use self::protocol::INSTANCE 
as PROTOCOL; @@ -71,12 +72,12 @@ pub enum TransactionError { UnsupportedWriterFeatures(Vec), /// Error returned when writer features are required but not specified - #[error("Writer features must be specified for writerversion >= 7")] - WriterFeaturesRequired, + #[error("Writer features must be specified for writerversion >= 7, please specify: {0:?}")] + WriterFeaturesRequired(WriterFeatures), /// Error returned when reader features are required but not specified - #[error("Reader features must be specified for reader version >= 3")] - ReaderFeaturesRequired, + #[error("Reader features must be specified for reader version >= 3, please specify: {0:?}")] + ReaderFeaturesRequired(ReaderFeatures), /// The transaction failed to commit due to an error in an implementation-specific layer. /// Currently used by DynamoDb-backed S3 log store when database operations fail. @@ -242,13 +243,19 @@ pub async fn commit_with_retries( attempt_number += 1; } Err(err) => { - log_store.object_store().delete(&tmp_commit).await?; + log_store + .object_store() + .delete_with_retries(&tmp_commit, 15) + .await?; return Err(TransactionError::CommitConflict(err).into()); } }; } Err(err) => { - log_store.object_store().delete(&tmp_commit).await?; + log_store + .object_store() + .delete_with_retries(&tmp_commit, 15) + .await?; return Err(err.into()); } } diff --git a/crates/core/src/operations/transaction/protocol.rs b/crates/core/src/operations/transaction/protocol.rs index 07a7b75405..d9ff45de5f 100644 --- a/crates/core/src/operations/transaction/protocol.rs +++ b/crates/core/src/operations/transaction/protocol.rs @@ -4,7 +4,7 @@ use lazy_static::lazy_static; use once_cell::sync::Lazy; use super::TransactionError; -use crate::kernel::{Action, ReaderFeatures, WriterFeatures}; +use crate::kernel::{Action, DataType, ReaderFeatures, Schema, WriterFeatures}; use crate::table::state::DeltaTableState; lazy_static! { @@ -76,6 +76,39 @@ impl ProtocolChecker { Ok(()) } + /// Check can write_timestamp_ntz + pub fn check_can_write_timestamp_ntz( + &self, + snapshot: &DeltaTableState, + schema: &Schema, + ) -> Result<(), TransactionError> { + let contains_timestampntz = schema + .fields() + .iter() + .any(|f| f.data_type() == &DataType::TIMESTAMPNTZ); + + let required_features: Option<&HashSet> = + match snapshot.protocol().min_writer_version { + 0..=6 => None, + _ => snapshot.protocol().writer_features.as_ref(), + }; + + if let Some(table_features) = required_features { + if !table_features.contains(&WriterFeatures::TimestampWithoutTimezone) + && contains_timestampntz + { + return Err(TransactionError::WriterFeaturesRequired( + WriterFeatures::TimestampWithoutTimezone, + )); + } + } else if contains_timestampntz { + return Err(TransactionError::WriterFeaturesRequired( + WriterFeatures::TimestampWithoutTimezone, + )); + } + Ok(()) + } + /// Check if delta-rs can read form the given delta table. pub fn can_read_from(&self, snapshot: &DeltaTableState) -> Result<(), TransactionError> { let required_features: Option<&HashSet> = @@ -139,7 +172,9 @@ impl ProtocolChecker { .protocol() .writer_features .as_ref() - .ok_or(TransactionError::WriterFeaturesRequired)? + .ok_or(TransactionError::WriterFeaturesRequired( + WriterFeatures::AppendOnly, + ))? .contains(&WriterFeatures::AppendOnly) && snapshot.table_config().append_only() }; @@ -164,11 +199,13 @@ impl ProtocolChecker { /// As we implement new features, we need to update this instance accordingly. /// resulting version support is determined by the supported table feature set. 
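For context on how the new `check_can_write_timestamp_ntz` guard is meant to surface: writing a batch that introduces a timezone-free timestamp column into a table whose protocol does not carry the timestampNtz writer feature should now fail up front with `WriterFeaturesRequired(TimestampWithoutTimezone)`. A rough sketch of that scenario, assuming the write path wires the check in as shown later in this diff (the error is only asserted generically here):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, TimestampMicrosecondArray};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use deltalake_core::{DeltaOps, DeltaResult};

#[tokio::test]
async fn write_ntz_into_old_protocol_table_is_rejected() -> DeltaResult<()> {
    // The initial table has no timestampNtz column, so it keeps the
    // default (pre-v7) writer protocol.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2]))])?;
    let table = DeltaOps::new_in_memory().write(vec![batch]).await?;

    // A follow-up write whose schema contains a timezone-free timestamp
    // column is refused before any data is written.
    let ntz_schema = Arc::new(Schema::new(vec![Field::new(
        "event_time",
        DataType::Timestamp(TimeUnit::Microsecond, None),
        true,
    )]));
    let ntz_batch = RecordBatch::try_new(
        ntz_schema,
        vec![Arc::new(TimestampMicrosecondArray::from(vec![Some(1)]))],
    )?;
    assert!(DeltaOps(table).write(vec![ntz_batch]).await.is_err());
    Ok(())
}
```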
pub static INSTANCE: Lazy = Lazy::new(|| { - let reader_features = HashSet::new(); + let mut reader_features = HashSet::new(); + reader_features.insert(ReaderFeatures::TimestampWithoutTimezone); // reader_features.insert(ReaderFeatures::ColumnMapping); let mut writer_features = HashSet::new(); writer_features.insert(WriterFeatures::AppendOnly); + writer_features.insert(WriterFeatures::TimestampWithoutTimezone); #[cfg(feature = "datafusion")] { writer_features.insert(WriterFeatures::Invariants); diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index d07f3f9fc0..803b1d0312 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -357,7 +357,7 @@ async fn execute( None, writer_properties, safe_cast, - false, + None, ) .await?; @@ -377,7 +377,7 @@ async fn execute( .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64; - let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add_actions.clone(); metrics.num_added_files = actions.len(); metrics.num_removed_files = candidates.candidates.len(); diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index 73c1599a7e..d80ddb8fa5 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -25,12 +25,14 @@ //! ```` use std::collections::HashMap; +use std::str::FromStr; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use std::vec; use arrow_array::RecordBatch; use arrow_cast::can_cast_types; -use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; +use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_expr::create_physical_expr; use datafusion::physical_plan::filter::FilterExec; @@ -50,8 +52,9 @@ use crate::delta_datafusion::expr::parse_predicate_expression; use crate::delta_datafusion::DeltaDataChecker; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, Add, PartitionsExt, Remove, StructType}; +use crate::kernel::{Action, Add, Metadata, PartitionsExt, Remove, StructType}; use crate::logstore::LogStoreRef; +use crate::operations::cast::{cast_record_batch, merge_schema}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; @@ -87,6 +90,30 @@ impl From for DeltaTableError { } } +///Specifies how to handle schema drifts +#[derive(PartialEq, Clone, Copy)] +pub enum SchemaMode { + /// Overwrite the schema with the new schema + Overwrite, + /// Append the new schema to the existing schema + Merge, +} + +impl FromStr for SchemaMode { + type Err = DeltaTableError; + + fn from_str(s: &str) -> DeltaResult { + match s.to_ascii_lowercase().as_str() { + "overwrite" => Ok(SchemaMode::Overwrite), + "merge" => Ok(SchemaMode::Merge), + _ => Err(DeltaTableError::Generic(format!( + "Invalid schema write mode provided: {}, only these are supported: ['overwrite', 'merge']", + s + ))), + } + } +} + /// Write data into a DeltaTable pub struct WriteBuilder { /// A snapshot of the to-be-loaded table's state @@ -109,8 +136,8 @@ pub struct WriteBuilder { write_batch_size: Option, /// RecordBatches to be written into the table batches: Option>, - /// whether to overwrite the schema - overwrite_schema: bool, + /// whether to overwrite the schema or to merge 
it. None means to fail on schmema drift + schema_mode: Option, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties @@ -140,7 +167,7 @@ impl WriteBuilder { write_batch_size: None, batches: None, safe_cast: false, - overwrite_schema: false, + schema_mode: None, writer_properties: None, app_metadata: None, name: None, @@ -155,9 +182,9 @@ impl WriteBuilder { self } - /// Add overwrite_schema - pub fn with_overwrite_schema(mut self, overwrite_schema: bool) -> Self { - self.overwrite_schema = overwrite_schema; + /// Add Schema Write Mode + pub fn with_schema_mode(mut self, schema_mode: SchemaMode) -> Self { + self.schema_mode = Some(schema_mode); self } @@ -258,6 +285,18 @@ impl WriteBuilder { match &self.snapshot { Some(snapshot) => { PROTOCOL.can_write_to(snapshot)?; + + if let Some(plan) = &self.input { + let schema: StructType = (plan.schema()).try_into()?; + PROTOCOL.check_can_write_timestamp_ntz(snapshot, &schema)?; + } else if let Some(batches) = &self.batches { + if batches.is_empty() { + return Err(WriteError::MissingData.into()); + } + let schema: StructType = (batches[0].schema()).try_into()?; + PROTOCOL.check_can_write_timestamp_ntz(snapshot, &schema)?; + } + match self.mode { SaveMode::ErrorIfExists => { Err(WriteError::AlreadyExists(self.log_store.root_uri()).into()) @@ -311,10 +350,9 @@ async fn write_execution_plan_with_predicate( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { - // Use input schema to prevent wrapping partitions columns into a dictionary. - let schema: ArrowSchemaRef = if overwrite_schema { + schema_mode: Option, +) -> DeltaResult> { + let schema: ArrowSchemaRef = if schema_mode.is_some() { plan.schema() } else { snapshot @@ -352,23 +390,29 @@ async fn write_execution_plan_with_predicate( let mut writer = DeltaWriter::new(object_store.clone(), config); let checker_stream = checker.clone(); let mut stream = inner_plan.execute(i, task_ctx)?; - let handle: tokio::task::JoinHandle>> = + let handle: tokio::task::JoinHandle>> = tokio::task::spawn(async move { while let Some(maybe_batch) = stream.next().await { let batch = maybe_batch?; checker_stream.check_batch(&batch).await?; - let arr = - super::cast::cast_record_batch(&batch, inner_schema.clone(), safe_cast)?; + let arr = super::cast::cast_record_batch( + &batch, + inner_schema.clone(), + safe_cast, + schema_mode == Some(SchemaMode::Merge), + )?; writer.write(&arr).await?; } - writer.close().await + let add_actions = writer.close().await; + match add_actions { + Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), + Err(err) => Err(err), + } }); tasks.push(handle); } - - // Collect add actions to add to commit - Ok(futures::future::join_all(tasks) + let actions = futures::future::join_all(tasks) .await .into_iter() .collect::, _>>() @@ -377,7 +421,9 @@ async fn write_execution_plan_with_predicate( .collect::, _>>()? 
.concat() .into_iter() - .collect::>()) + .collect::>(); + // Collect add actions to add to commit + Ok(actions) } #[allow(clippy::too_many_arguments)] @@ -391,8 +437,8 @@ pub(crate) async fn write_execution_plan( write_batch_size: Option, writer_properties: Option, safe_cast: bool, - overwrite_schema: bool, -) -> DeltaResult> { + schema_mode: Option, +) -> DeltaResult> { write_execution_plan_with_predicate( None, snapshot, @@ -404,7 +450,7 @@ pub(crate) async fn write_execution_plan( write_batch_size, writer_properties, safe_cast, - overwrite_schema, + schema_mode, ) .await } @@ -417,7 +463,7 @@ async fn execute_non_empty_expr( expression: &Expr, rewrite: &[Add], writer_properties: Option, -) -> DeltaResult> { +) -> DeltaResult> { // For each identified file perform a parquet scan + filter + limit (1) + count. // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. @@ -452,7 +498,7 @@ async fn execute_non_empty_expr( None, writer_properties, false, - false, + None, ) .await?; @@ -488,7 +534,7 @@ async fn prepare_predicate_actions( }; let remove = candidates.candidates; - let mut actions: Vec = add.into_iter().map(Action::Add).collect(); + let mut actions: Vec = add.into_iter().collect(); for action in remove { actions.push(Action::Remove(Remove { @@ -520,6 +566,11 @@ impl std::future::IntoFuture for WriteBuilder { PROTOCOL.check_append_only(snapshot)?; } } + if this.schema_mode == Some(SchemaMode::Overwrite) && this.mode != SaveMode::Overwrite { + return Err(DeltaTableError::Generic( + "Schema overwrite not supported for Append".to_string(), + )); + } // Create table actions to initialize table in case it does not yet exist and should be created let mut actions = this.check_preconditions().await?; @@ -546,8 +597,13 @@ impl std::future::IntoFuture for WriteBuilder { } else { Ok(this.partition_columns.unwrap_or_default()) }?; - + let mut schema_drift = false; let plan = if let Some(plan) = this.input { + if this.schema_mode == Some(SchemaMode::Merge) { + return Err(DeltaTableError::Generic( + "Schema merge not supported yet for Datafusion".to_string(), + )); + } Ok(plan) } else if let Some(batches) = this.batches { if batches.is_empty() { @@ -555,6 +611,7 @@ impl std::future::IntoFuture for WriteBuilder { } else { let schema = batches[0].schema(); + let mut new_schema = None; if let Some(snapshot) = &this.snapshot { let table_schema = snapshot .physical_arrow_schema(this.log_store.object_store().clone()) @@ -562,23 +619,38 @@ impl std::future::IntoFuture for WriteBuilder { .or_else(|_| snapshot.arrow_schema()) .unwrap_or(schema.clone()); - if !can_cast_batch(schema.fields(), table_schema.fields()) - && !(this.overwrite_schema && matches!(this.mode, SaveMode::Overwrite)) + if let Err(schema_err) = + try_cast_batch(schema.fields(), table_schema.fields()) { - return Err(DeltaTableError::Generic( - "Schema of data does not match table schema".to_string(), - )); - }; + schema_drift = true; + if this.mode == SaveMode::Overwrite && this.schema_mode.is_some() { + new_schema = None // we overwrite anyway, so no need to cast + } else if this.schema_mode == Some(SchemaMode::Merge) { + new_schema = Some(Arc::new(merge_schema( + table_schema.as_ref().clone(), + schema.as_ref().clone(), + )?)); + } else { + return Err(schema_err.into()); + } + } } let data = if !partition_columns.is_empty() { // TODO partitioning should probably happen in its own plan ... 
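The partitioning and schema-merge code below leans on the new `add_missing` flag threaded through `cast_record_batch` earlier in this diff: columns that exist in the target schema but not in the incoming batch are filled with nulls instead of raising a `SchemaError`. A small isolated sketch of that flag, assuming the function stays publicly reachable via `deltalake_core::operations::cast`:

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use deltalake_core::operations::cast::cast_record_batch;
use deltalake_core::DeltaResult;

fn fill_missing_columns() -> DeltaResult<()> {
    // The incoming batch only carries `id`.
    let source_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));
    let batch = RecordBatch::try_new(
        source_schema,
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // The (merged) target schema additionally expects a nullable `comment`.
    let target_schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, true),
        Field::new("comment", DataType::Utf8, true),
    ]));

    // safe = false, add_missing = true: the absent `comment` column comes
    // back as an all-null array rather than an error.
    let casted = cast_record_batch(&batch, target_schema, false, true)?;
    assert_eq!(casted.num_columns(), 2);
    assert_eq!(casted.column(1).null_count(), 3);
    Ok(())
}
```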
let mut partitions: HashMap> = HashMap::new(); for batch in batches { + let real_batch = match new_schema.clone() { + Some(new_schema) => { + cast_record_batch(&batch, new_schema, false, true)? + } + None => batch, + }; + let divided = divide_by_partition_values( - schema.clone(), + new_schema.clone().unwrap_or(schema.clone()), partition_columns.clone(), - &batch, + &real_batch, )?; for part in divided { let key = part.partition_values.hive_partition_path(); @@ -594,17 +666,44 @@ impl std::future::IntoFuture for WriteBuilder { } partitions.into_values().collect::>() } else { - vec![batches] + match new_schema { + Some(ref new_schema) => { + let mut new_batches = vec![]; + for batch in batches { + new_batches.push(cast_record_batch( + &batch, + new_schema.clone(), + false, + true, + )?); + } + vec![new_batches] + } + None => vec![batches], + } }; - Ok(Arc::new(MemoryExec::try_new(&data, schema.clone(), None)?) - as Arc) + Ok(Arc::new(MemoryExec::try_new( + &data, + new_schema.unwrap_or(schema).clone(), + None, + )?) as Arc) } } else { Err(WriteError::MissingData) }?; let schema = plan.schema(); - + if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { + if let Some(snapshot) = &this.snapshot { + let schema_struct: StructType = schema.clone().try_into()?; + let schema_action = Action::Metadata(Metadata::try_new( + schema_struct, + partition_columns.clone(), + snapshot.metadata().configuration.clone(), + )?); + actions.push(schema_action); + } + } let state = match this.state { Some(state) => state, None => { @@ -641,10 +740,10 @@ impl std::future::IntoFuture for WriteBuilder { this.write_batch_size, this.writer_properties.clone(), this.safe_cast, - this.overwrite_schema, + this.schema_mode, ) .await?; - actions.extend(add_actions.into_iter().map(Action::Add)); + actions.extend(add_actions); // Collect remove actions if we are overwriting the table if let Some(snapshot) = &this.snapshot { @@ -729,24 +828,42 @@ impl std::future::IntoFuture for WriteBuilder { } } -fn can_cast_batch(from_fields: &Fields, to_fields: &Fields) -> bool { +fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { if from_fields.len() != to_fields.len() { - return false; + return Err(ArrowError::SchemaError(format!( + "Cannot cast schema, number of fields does not match: {} vs {}", + from_fields.len(), + to_fields.len() + ))); } - from_fields.iter().all(|f| { - if let Some((_, target_field)) = to_fields.find(f.name()) { - if let (DataType::Struct(fields0), DataType::Struct(fields1)) = - (f.data_type(), target_field.data_type()) - { - can_cast_batch(fields0, fields1) + from_fields + .iter() + .map(|f| { + if let Some((_, target_field)) = to_fields.find(f.name()) { + if let (DataType::Struct(fields0), DataType::Struct(fields1)) = + (f.data_type(), target_field.data_type()) + { + try_cast_batch(fields0, fields1) + } else if !can_cast_types(f.data_type(), target_field.data_type()) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } else { + Ok(()) + } } else { - can_cast_types(f.data_type(), target_field.data_type()) + Err(ArrowError::SchemaError(format!( + "Field {} not found in schema", + f.name() + ))) } - } else { - false - } - }) + }) + .collect::, _>>()?; + Ok(()) } #[cfg(test)] @@ -941,23 +1058,25 @@ mod tests { let schema = Arc::new(ArrowSchema::new(vec![Field::new( "value", - DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Microsecond, 
Some("UTC".to_string().into())), true, )])); let batch = RecordBatch::try_new( Arc::clone(&schema), - vec![Arc::new(TimestampMicrosecondArray::from(vec![Some(10000)]))], + vec![Arc::new( + TimestampMicrosecondArray::from(vec![Some(10000)]).with_timezone("UTC"), + )], ) .unwrap(); let _res = DeltaOps::from(table).write(vec![batch]).await.unwrap(); let expected = [ - "+-------------------------+", - "| value |", - "+-------------------------+", - "| 1970-01-01T00:00:00.010 |", - "| 2023-06-03 15:35:00 |", - "+-------------------------+", + "+--------------------------+", + "| value |", + "+--------------------------+", + "| 1970-01-01T00:00:00.010Z |", + "| 2023-06-03 15:35:00 |", + "+--------------------------+", ]; let actual = get_data(&_res).await; assert_batches_sorted_eq!(&expected, &actual); @@ -997,6 +1116,219 @@ mod tests { assert_eq!(table.get_files_count(), 4) } + #[tokio::test] + async fn test_merge_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + + let mut table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + table.load().await.unwrap(); + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let names = fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); + } + + #[tokio::test] + async fn test_merge_schema_with_partitions() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_partition_columns(vec!["id", "value"]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = 
RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + println!("new_batch: {:?}", new_batch.schema()); + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Merge) + .await + .unwrap(); + + assert_eq!(table.version(), 1); + let new_schema = table.metadata().unwrap().schema().unwrap(); + let fields = new_schema.fields(); + let mut names = fields.iter().map(|f| f.name()).collect::>(); + names.sort(); + assert_eq!(names, vec!["id", "inserted_by", "modified", "value"]); + let part_cols = table.metadata().unwrap().partition_columns.clone(); + assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions + } + + #[tokio::test] + async fn test_overwrite_schema() { + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + for field in batch.schema().fields() { + if field.name() != "modified" { + new_schema_builder.push(field.clone()); + } + } + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["id", "value", "inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = RecordBatch::try_new( + Arc::new(new_schema), + vec![ + Arc::new(batch.column_by_name("id").unwrap().clone()), + Arc::new(batch.column_by_name("value").unwrap().clone()), + Arc::new(inserted_by), + ], + ) + .unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .with_schema_mode(SchemaMode::Overwrite) + .await; + assert!(table.is_err()); + } + + #[tokio::test] + async fn test_overwrite_check() { + // If you do not pass a schema mode, we want to check the schema + let batch = get_record_batch(None, false); + let table = DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(SaveMode::ErrorIfExists) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); + + new_schema_builder.push(Field::new("inserted_by", DataType::Utf8, true)); + let new_schema = new_schema_builder.finish(); + let new_fields = new_schema.fields(); + let new_names = new_fields.iter().map(|f| f.name()).collect::>(); + assert_eq!(new_names, vec!["inserted_by"]); + let inserted_by = StringArray::from(vec![ + Some("A1"), + Some("B1"), + None, + Some("B2"), + Some("A3"), + Some("A4"), + None, + None, + Some("B4"), + Some("A5"), + Some("A7"), + ]); + let new_batch = + RecordBatch::try_new(Arc::new(new_schema), vec![Arc::new(inserted_by)]).unwrap(); + + let table = DeltaOps(table) + .write(vec![new_batch]) + .with_save_mode(SaveMode::Append) + .await; + assert!(table.is_err()); + } + #[tokio::test] async fn test_check_invariants() { let batch = get_record_batch(None, false); diff --git a/crates/core/src/protocol/mod.rs b/crates/core/src/protocol/mod.rs index 3be8a734fa..7267529b74 100644 --- 
a/crates/core/src/protocol/mod.rs +++ b/crates/core/src/protocol/mod.rs @@ -371,6 +371,12 @@ pub enum DeltaOperation { expr: String, }, + /// Drops constraints from a table + DropConstraint { + /// Constraints name + name: String, + }, + /// Merge data with a source data with the following predicate #[serde(rename_all = "camelCase")] Merge { @@ -458,6 +464,7 @@ impl DeltaOperation { DeltaOperation::VacuumStart { .. } => "VACUUM START", DeltaOperation::VacuumEnd { .. } => "VACUUM END", DeltaOperation::AddConstraint { .. } => "ADD CONSTRAINT", + DeltaOperation::DropConstraint { .. } => "DROP CONSTRAINT", } } @@ -496,7 +503,8 @@ impl DeltaOperation { Self::Optimize { .. } | Self::VacuumStart { .. } | Self::VacuumEnd { .. } - | Self::AddConstraint { .. } => false, + | Self::AddConstraint { .. } + | Self::DropConstraint { .. } => false, Self::Create { .. } | Self::FileSystemCheck {} | Self::StreamingUpdate { .. } @@ -1298,15 +1306,21 @@ mod tests { ), ( "min.timestamp", - Arc::new(array::TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), - ])), + Arc::new( + array::TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), + ]) + .with_timezone("UTC"), + ), ), ( "max.timestamp", - Arc::new(array::TimestampMicrosecondArray::from(vec![ - TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), - ])), + Arc::new( + array::TimestampMicrosecondArray::from(vec![ + TimestampMicrosecondType::parse("2022-10-24T22:59:32.846Z"), + ]) + .with_timezone("UTC"), + ), ), ( "null_count.struct.struct_element", diff --git a/crates/core/src/storage/mod.rs b/crates/core/src/storage/mod.rs index 583c7f0632..9b21645add 100644 --- a/crates/core/src/storage/mod.rs +++ b/crates/core/src/storage/mod.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use url::Url; pub mod file; +pub mod retry_ext; pub mod utils; use crate::{DeltaResult, DeltaTableError}; @@ -23,6 +24,7 @@ pub use object_store::{ DynObjectStore, Error as ObjectStoreError, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result as ObjectStoreResult, }; +pub use retry_ext::ObjectStoreRetryExt; pub use utils::*; lazy_static! { diff --git a/crates/core/src/storage/retry_ext.rs b/crates/core/src/storage/retry_ext.rs new file mode 100644 index 0000000000..b9f7ebf691 --- /dev/null +++ b/crates/core/src/storage/retry_ext.rs @@ -0,0 +1,82 @@ +//! Retry extension for [`ObjectStore`] + +use bytes::Bytes; +use object_store::{path::Path, Error, ObjectStore, PutResult, Result}; +use tracing::log::*; + +/// Retry extension for [`ObjectStore`] +/// +/// Read-only operations are retried by [`ObjectStore`] internally. However, PUT/DELETE operations +/// are not retried even thought they are technically idempotent. [`ObjectStore`] does not retry +/// those operations because having preconditions may produce different results for the same +/// request. PUT/DELETE operations without preconditions are idempotent and can be retried. +/// Unfortunately, [`ObjectStore`]'s retry mechanism only works on HTTP request level, thus there +/// is no way to distinguish whether a request has preconditions or not. +/// +/// This trait provides additional methods for working with [`ObjectStore`] that automatically retry +/// unconditional operations when they fail. 
+/// +/// See also: +/// - https://github.com/apache/arrow-rs/pull/5278 +#[async_trait::async_trait] +pub trait ObjectStoreRetryExt: ObjectStore { + /// Save the provided bytes to the specified location + /// + /// The operation is guaranteed to be atomic, it will either successfully write the entirety of + /// bytes to location, or fail. No clients should be able to observe a partially written object + /// + /// Note that `put_with_opts` may have precondition semantics, and thus may not be retriable. + async fn put_with_retries( + &self, + location: &Path, + bytes: Bytes, + max_retries: usize, + ) -> Result { + let mut attempt_number = 1; + while attempt_number <= max_retries { + match self.put(location, bytes.clone()).await { + Ok(result) => return Ok(result), + Err(err) if attempt_number == max_retries => { + return Err(err); + } + Err(Error::Generic { store, source }) => { + debug!( + "put_with_retries attempt {} failed: {} {}", + attempt_number, store, source + ); + attempt_number += 1; + } + Err(err) => { + return Err(err); + } + } + } + unreachable!("loop yields Ok or Err in body when attempt_number = max_retries") + } + + /// Delete the object at the specified location + async fn delete_with_retries(&self, location: &Path, max_retries: usize) -> Result<()> { + let mut attempt_number = 1; + while attempt_number <= max_retries { + match self.delete(location).await { + Ok(()) | Err(Error::NotFound { .. }) => return Ok(()), + Err(err) if attempt_number == max_retries => { + return Err(err); + } + Err(Error::Generic { store, source }) => { + debug!( + "delete_with_retries attempt {} failed: {} {}", + attempt_number, store, source + ); + attempt_number += 1; + } + Err(err) => { + return Err(err); + } + } + } + unreachable!("loop yields Ok or Err in body when attempt_number = max_retries") + } +} + +impl ObjectStoreRetryExt for T {} diff --git a/crates/core/src/table/config.rs b/crates/core/src/table/config.rs index 24b11a01a4..ff33c1597b 100644 --- a/crates/core/src/table/config.rs +++ b/crates/core/src/table/config.rs @@ -249,7 +249,7 @@ impl<'a> TableConfig<'a> { ( "true to enable deletion vectors and predictive I/O for updates.", DeltaConfigKey::EnableDeletionVectors, - enable_deletio0n_vectors, + enable_deletion_vectors, bool, // in databricks the default is dependent on the workspace settings and runtime version // https://learn.microsoft.com/en-us/azure/databricks/administration-guide/workspace-settings/deletion-vectors @@ -289,7 +289,7 @@ impl<'a> TableConfig<'a> { /// than this value. Otherwise, the query may not be able to restart, as it must still read old files. pub fn deleted_file_retention_duration(&self) -> Duration { lazy_static! { - static ref DEFAULT_DURATION: Duration = parse_interval("interval 1 week").unwrap(); + static ref DEFAULT_DURATION: Duration = parse_interval("interval 1 weeks").unwrap(); } self.0 .get(DeltaConfigKey::DeletedFileRetentionDuration.as_ref()) @@ -305,7 +305,7 @@ impl<'a> TableConfig<'a> { /// constant time. Operations on history are parallel but will become more expensive as the log size increases. pub fn log_retention_duration(&self) -> Duration { lazy_static! { - static ref DEFAULT_DURATION: Duration = parse_interval("interval 30 day").unwrap(); + static ref DEFAULT_DURATION: Duration = parse_interval("interval 30 days").unwrap(); } self.0 .get(DeltaConfigKey::LogRetentionDuration.as_ref()) @@ -525,14 +525,14 @@ fn parse_interval(value: &str) -> Result { let number = number as u64; let duration = match it.next().ok_or_else(not_an_interval)? 
{ - "nanosecond" => Duration::from_nanos(number), - "microsecond" => Duration::from_micros(number), - "millisecond" => Duration::from_millis(number), - "second" => Duration::from_secs(number), - "minute" => Duration::from_secs(number * SECONDS_PER_MINUTE), - "hour" => Duration::from_secs(number * SECONDS_PER_HOUR), - "day" => Duration::from_secs(number * SECONDS_PER_DAY), - "week" => Duration::from_secs(number * SECONDS_PER_WEEK), + "nanosecond" | "nanoseconds" => Duration::from_nanos(number), + "microsecond" | "microseconds" => Duration::from_micros(number), + "millisecond" | "milliseconds" => Duration::from_millis(number), + "second" | "seconds" => Duration::from_secs(number), + "minute" | "minutes" => Duration::from_secs(number * SECONDS_PER_MINUTE), + "hour" | "hours" => Duration::from_secs(number * SECONDS_PER_HOUR), + "day" | "days" => Duration::from_secs(number * SECONDS_PER_DAY), + "week" | "weeks" => Duration::from_secs(number * SECONDS_PER_WEEK), unit => { return Err(DeltaConfigError::Validation(format!( "Unknown unit '{unit}'" @@ -620,36 +620,76 @@ mod tests { Duration::from_nanos(123) ); + assert_eq!( + parse_interval("interval 123 nanoseconds").unwrap(), + Duration::from_nanos(123) + ); + assert_eq!( parse_interval("interval 123 microsecond").unwrap(), Duration::from_micros(123) ); + assert_eq!( + parse_interval("interval 123 microseconds").unwrap(), + Duration::from_micros(123) + ); + assert_eq!( parse_interval("interval 123 millisecond").unwrap(), Duration::from_millis(123) ); + assert_eq!( + parse_interval("interval 123 milliseconds").unwrap(), + Duration::from_millis(123) + ); + assert_eq!( parse_interval("interval 123 second").unwrap(), Duration::from_secs(123) ); + assert_eq!( + parse_interval("interval 123 seconds").unwrap(), + Duration::from_secs(123) + ); + assert_eq!( parse_interval("interval 123 minute").unwrap(), Duration::from_secs(123 * 60) ); + assert_eq!( + parse_interval("interval 123 minutes").unwrap(), + Duration::from_secs(123 * 60) + ); + assert_eq!( parse_interval("interval 123 hour").unwrap(), Duration::from_secs(123 * 3600) ); + assert_eq!( + parse_interval("interval 123 hours").unwrap(), + Duration::from_secs(123 * 3600) + ); + assert_eq!( parse_interval("interval 123 day").unwrap(), Duration::from_secs(123 * 86400) ); + assert_eq!( + parse_interval("interval 123 days").unwrap(), + Duration::from_secs(123 * 86400) + ); + + assert_eq!( + parse_interval("interval 123 week").unwrap(), + Duration::from_secs(123 * 604800) + ); + assert_eq!( parse_interval("interval 123 week").unwrap(), Duration::from_secs(123 * 604800) diff --git a/crates/core/src/table/state_arrow.rs b/crates/core/src/table/state_arrow.rs index 143ab23d1c..2c9ed74ce3 100644 --- a/crates/core/src/table/state_arrow.rs +++ b/crates/core/src/table/state_arrow.rs @@ -737,8 +737,8 @@ fn json_value_to_array_general<'a>( .map(|value| value.and_then(|value| value.as_str().map(|value| value.as_bytes()))) .collect_vec(), ))), - DataType::Timestamp(TimeUnit::Microsecond, None) => { - Ok(Arc::new(TimestampMicrosecondArray::from( + DataType::Timestamp(TimeUnit::Microsecond, tz) => match tz { + None => Ok(Arc::new(TimestampMicrosecondArray::from( values .map(|value| { value.and_then(|value| { @@ -746,13 +746,32 @@ fn json_value_to_array_general<'a>( }) }) .collect_vec(), - ))) - } + ))), + Some(tz_str) if tz_str.as_ref() == "UTC" => Ok(Arc::new( + TimestampMicrosecondArray::from( + values + .map(|value| { + value.and_then(|value| { + value.as_str().and_then(TimestampMicrosecondType::parse) + }) + }) + 
.collect_vec(), + ) + .with_timezone("UTC"), + )), + _ => Err(DeltaTableError::Generic(format!( + "Invalid datatype {}", + datatype + ))), + }, DataType::Date32 => Ok(Arc::new(Date32Array::from( values .map(|value| value.and_then(|value| value.as_str().and_then(Date32Type::parse))) .collect_vec(), ))), - _ => Err(DeltaTableError::Generic("Invalid datatype".to_string())), + _ => Err(DeltaTableError::Generic(format!( + "Invalid datatype {}", + datatype + ))), } } diff --git a/crates/core/src/writer/json.rs b/crates/core/src/writer/json.rs index 6740bc0204..8cc908320e 100644 --- a/crates/core/src/writer/json.rs +++ b/crates/core/src/writer/json.rs @@ -25,6 +25,7 @@ use super::utils::{ use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; use crate::kernel::{Add, PartitionsExt, Scalar, StructType}; +use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::writer::utils::ShareableBuffer; use crate::DeltaTable; @@ -360,7 +361,7 @@ impl DeltaWriter> for JsonWriter { let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; - self.storage.put(&path, obj_bytes).await?; + self.storage.put_with_retries(&path, obj_bytes, 15).await?; actions.push(create_add( &writer.partition_values, diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index c62fc9b560..5c8fb57509 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -29,6 +29,8 @@ use super::utils::{ use super::{DeltaWriter, DeltaWriterError, WriteMode}; use crate::errors::DeltaTableError; use crate::kernel::{Action, Add, PartitionsExt, Scalar, StructType}; +use crate::operations::cast::merge_schema; +use crate::storage::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::DeltaTable; @@ -215,7 +217,7 @@ impl DeltaWriter for RecordBatchWriter { let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; - self.storage.put(&path, obj_bytes).await?; + self.storage.put_with_retries(&path, obj_bytes, 15).await?; actions.push(create_add( &writer.partition_values, @@ -304,10 +306,10 @@ impl PartitionWriter { WriteMode::MergeSchema => { debug!("The writer and record batch schemas do not match, merging"); - let merged = ArrowSchema::try_merge(vec![ + let merged = merge_schema( self.arrow_schema.as_ref().clone(), record_batch.schema().as_ref().clone(), - ])?; + )?; self.arrow_schema = Arc::new(merged); let mut cols = vec![]; diff --git a/crates/core/src/writer/stats.rs b/crates/core/src/writer/stats.rs index 40e7b303cc..20201c949c 100644 --- a/crates/core/src/writer/stats.rs +++ b/crates/core/src/writer/stats.rs @@ -645,7 +645,6 @@ mod tests { } } - #[ignore] #[tokio::test] async fn test_delta_stats() { let temp_dir = tempfile::tempdir().unwrap(); diff --git a/crates/core/tests/command_optimize.rs b/crates/core/tests/command_optimize.rs index e26e642e73..deda18928b 100644 --- a/crates/core/tests/command_optimize.rs +++ b/crates/core/tests/command_optimize.rs @@ -416,7 +416,6 @@ async fn test_commit_interval() -> Result<(), Box> { } #[tokio::test] -#[ignore] /// Validate that bin packing is idempotent. async fn test_idempotent() -> Result<(), Box> { //TODO: Compression makes it hard to get the target file size... 
diff --git a/crates/core/tests/integration_datafusion.rs b/crates/core/tests/integration_datafusion.rs index 90fc3ea9fa..4be66534fe 100644 --- a/crates/core/tests/integration_datafusion.rs +++ b/crates/core/tests/integration_datafusion.rs @@ -982,7 +982,7 @@ mod local { Ok(()) } - + #[ignore] #[tokio::test] async fn test_issue_1374() -> Result<()> { let ctx = SessionContext::new(); diff --git a/crates/deltalake/Cargo.toml b/crates/deltalake/Cargo.toml index 3104080489..c541cc1284 100644 --- a/crates/deltalake/Cargo.toml +++ b/crates/deltalake/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake" -version = "0.17.0" +version = "0.17.1" authors.workspace = true keywords.workspace = true readme.workspace = true @@ -11,8 +11,13 @@ license.workspace = true repository.workspace = true rust-version.workspace = true +[package.metadata.docs.rs] +# We cannot use all_features because TLS features are mutually exclusive. +# We cannot use hdfs feature because it requires Java to be installed. +features = ["azure", "datafusion", "gcs", "hdfs", "json", "python", "s3", "unity-experimental"] + [dependencies] -deltalake-core = { version = "0.17.0", path = "../core" } +deltalake-core = { version = "0.17.1", path = "../core" } deltalake-aws = { version = "0.1.0", path = "../aws", default-features = false, optional = true } deltalake-azure = { version = "0.1.0", path = "../azure", optional = true } deltalake-gcp = { version = "0.1.0", path = "../gcp", optional = true } diff --git a/docs/delta-lake-best-practices.md b/docs/delta-lake-best-practices.md new file mode 100644 index 0000000000..c83c7db7a3 --- /dev/null +++ b/docs/delta-lake-best-practices.md @@ -0,0 +1,121 @@ +# Delta Lake Best Practices + +This page outlines Delta Lake best practices. + +You should consider several factors to optimize the performance of your Delta tables and minimize costs. + +The Delta Lake best practices depend on your data ingestion into the Delta table and query patterns. You must understand your data and how users run queries to best leverage Delta Lake. + +## Compacting small files + +Delta tables work best when the files are “right-sized”. Files that are too small create I/O overhead. Files that are too large limit the impact of file skipping (a critical query optimization). + +Delta tables can accumulate a lot of small files, especially if you’re frequently writing small amounts of data. If your table has many small files, you should run a small compaction operation to consolidate all the tiny files into “right-sized” files. + +It’s generally best for files in a Delta table to be between 100MB and 1GB, but that can vary based on the overall size of the table and the query patterns. + +Delta Lake makes it easy to [compact the small files](https://delta-io.github.io/delta-rs/usage/optimize/small-file-compaction-with-optimize/). + +## Optimizing table layout + +You can colocate similar data in the same files to make file skipping more effective. Delta Lake supports [Z Ordering](https://delta-io.github.io/delta-rs/usage/optimize/delta-lake-z-order/), which can colocate similar data in the same files. + +Z Ordering can yield impressive performance gains for low-cardinality columns but also works well for high-cardinality columns. This is an advantage compared to Hive-style partitioning, which is only suitable for low-cardinality columns. + +You must analyze the most common query patterns and Z Order your dataset based on the columns allowing the most file skipping. 
The ability to colocate data in the same files decreases when you add more Z Order columns.
+
+Let’s look at Hive-style partitioning, another way to colocate data in the same files. You can also use Hive-style partitioning in conjunction with Z Ordering.
+
+## Partitioning datasets
+
+You can partition your Delta tables, which separates the data by one or more partition keys into separate folders. Partitioning can be an excellent performance optimization (when you filter on the partition key) and is a good way to sidestep concurrency conflict issues.
+
+Hive-style partitioning also has some significant downsides.
+
+* It’s only suitable for low-cardinality columns.
+* It can create many small files, especially if you use the wrong partition key or frequently update the Delta table.
+* It can cause some queries that don’t rely on the partition key to run slower (because of the excessive number of small files). A large number of small files is problematic for I/O throughput.
+
+Hive-style partitioning can be a great data management tactic and a fantastic option for many Delta tables. Beware of the downsides before partitioning your tables.
+
+You can use Hive-style partitioning in conjunction with Z Ordering. You can partition a table by one column and Z Order by another. They’re different tactics that aim to help you skip more files and run queries faster.
+
+Let’s look at some of the built-in Delta features that help maintain the integrity of your tables.
+
+## Use appropriate quality controls
+
+Delta Lake supports schema enforcement and column constraints to protect the integrity of your data.
+
+Delta Lake enables schema enforcement by default, so you can only append data to an existing table with the same exact schema. You can bypass schema enforcement by enabling schema evolution, which allows you to append mismatched schemas to a table.
+
+You should only enable schema evolution when you want to allow the schema of your table to change. You should not enable schema evolution if you don’t want this flexibility. Schema enforcement is a good default setting.
+
+Column-level constraints prevent you from appending data that fails SQL predicates. For example, you may add a constraint that requires all the values in the `age` column of a table to be positive.
+
+You should add column constraints to your table whenever you want a column to include only values that satisfy a predicate.
+
+No data is appended when you apply a constraint and a row check fails. For example, if you try to append 100 rows of data to a table and one row has a failing check, then no data is added.
+
+When you have column constraints, it’s often a good idea to append the failing data to a “quarantine table” and the passing data to the main Delta table. Or you can filter out the failing rows and just append the passing rows. Keeping a history of the failing rows in a quarantine table is helpful for debugging.
+
+See here to learn more about [Delta Lake constraints](https://delta-io.github.io/delta-rs/usage/constraints/).
+
+## Best practices for DML operations
+
+DML operations like deleting, updating, and merging rewrite existing data into new files and mark existing files for deletion in the transaction log. Rewriting data files is expensive, so you want to minimize the number of rewritten files when you run DML operations.
+
+Delta Lake supports a table feature called deletion vectors that implements DML transactions more efficiently under the hood.
Enabling deletion vectors is usually the best way to make DML operations run faster. Note: delta-rs doesn’t support deletion vectors yet.
+
+You should periodically purge deletion vectors because they can accumulate and slow subsequent read operations. Once you enable the feature, you must purge the deletion vectors in your table with an appropriate cadence.
+
+## Use vacuum to save storage costs
+
+Delta Lake supports transactions, which necessitates keeping old versions of data in storage, even the files marked for removal in the transaction log.
+
+Keeping old versions of Delta tables in storage is often desirable because it allows for versioned data, time travel, and rolling back tables to a previous state.
+
+If you don’t want to leverage older versions of a table, then you should remove the legacy files from storage with the vacuum command. Vacuum will remove all files that are both older than the table retention period and marked for removal in the transaction log.
+
+You only need to vacuum when you perform operations that mark files for removal in the transaction log. An append-only table doesn’t create legacy files that need to be vacuumed.
+
+Create a good vacuum strategy for your tables to minimize your storage costs.
+
+## Delta Lake best practices to minimize costs
+
+Delta Lake helps you minimize costs in many ways:
+
+* It's a free, open source format (based on Parquet). It's not a proprietary format that you need to pay for.
+* Delta tables store column-level min/max values in the transaction log, allowing file skipping.
+* Delta tables can be optimized (small file compaction, Z Ordering, etc.), so your queries run faster. When your queries run faster, you pay less for compute.
+* Deletion vectors let you perform DML operations (delete, update, merge) much faster. If your delete operation runs 100x faster, you pay 100x less for compute.
+* It's easy to remove legacy files from storage with VACUUM, which minimizes storage costs.
+
+You should understand your organization’s query patterns and use these features to minimize the overall cost. You need to assess tradeoffs. For example, Z Ordering is a computation that costs money, but it can save you lots of money in the long run if all your subsequent queries run a lot faster and use less compute.
+
+## Collect metadata stats on columns used for file skipping
+
+Delta tables don’t always store each column's min/max values. Some Delta Lake implementations only store min/max values for the first 32 columns in the table, for example.
+
+Delta Lake can only apply file skipping when it has min/max values for the relevant columns stored in the transaction log. Suppose you’re running a filtering operation on `col_a`, for example. Delta Lake can only apply file skipping when the transaction log stores `col_a` min/max metadata.
+
+Ensure the transaction log stores metadata stats for all the columns that benefit from file skipping.
+
+## Don’t collect column metadata when it’s unnecessary
+
+It takes some time to compute column statistics when writing files, and it isn’t worth the effort if you cannot use the column for file skipping.
+
+Suppose you have a table column containing a long string of arbitrary text. It’s unlikely that this column would ever provide any data-skipping benefits. So, you can just avoid the overhead of collecting the statistics for this particular column.
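+
+As an illustration, the sketch below creates a table whose statistics cover only its first column, using the `delta.dataSkippingNumIndexedCols` table property defined by the Delta protocol. The `configuration` argument of `write_deltalake` is real, but whether a given writer and reader honor this property for statistics collection and file skipping depends on the engine, so treat this as a sketch rather than a guarantee:
+
+```python
+import pandas as pd
+
+from deltalake import write_deltalake
+
+df = pd.DataFrame(
+    {
+        "id": [1, 2, 3],
+        "free_text": ["lorem", "ipsum", "dolor"],  # arbitrary text, useless for file skipping
+    }
+)
+
+# Ask the engine to collect min/max statistics only for the first column ("id").
+write_deltalake(
+    "tmp/stats-example",
+    df,
+    configuration={"delta.dataSkippingNumIndexedCols": "1"},
+)
+```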
+
+## Additional reading
+
+Delta Lake relies on transactions, and you should check out [this page to learn more](https://delta-io.github.io/delta-rs/how-delta-lake-works/delta-lake-acid-transactions/).
+
+Many Delta Lake performance benefits rely on [file skipping](https://delta-io.github.io/delta-rs/how-delta-lake-works/delta-lake-file-skipping/), which you should understand well to get the most out of Delta.
+
+## Conclusion
+
+Delta Lake is a powerful technology that makes your data pipelines more reliable, saves money, and makes everyday data processing tasks easy.
+
+You need to learn how Delta Lake works at a high level to leverage Delta's power fully. You will not be able to leverage Delta Lake’s full performance potential if your table has improperly sized files or if you’re not colocating data in the same files to maximize data skipping, for example.
+
+Luckily, there are only a few details that are important to learn. You don’t need to know the implementation details, just the essential high-level concepts.
diff --git a/docs/integrations/delta-lake-dagster.md b/docs/integrations/delta-lake-dagster.md
new file mode 100644
index 0000000000..0fe413c7a0
--- /dev/null
+++ b/docs/integrations/delta-lake-dagster.md
@@ -0,0 +1,278 @@
+# Using Delta Lake with Dagster
+
+Delta Lake is a great storage format for Dagster workflows. This page will explain why and how to use Delta Lake with Dagster.
+
+You will learn how to use the Delta Lake I/O Manager to read and write your Dagster Software-Defined Assets (SDAs). You will also learn about the unique advantages Delta Lake offers the Dagster community.
+
+Here are some of the benefits that Delta Lake provides Dagster users:
+- native PyArrow integration for lazy computation of large datasets
+- more efficient querying with file skipping via Z Ordering and liquid clustering
+- built-in vacuuming to remove unnecessary files and versions
+- ACID transactions for reliable writes
+- smooth versioning integration so that versions can be used to trigger downstream updates
+- surfacing table stats based on the file statistics
+
+
+## Dagster I/O Managers
+Dagster uses [I/O Managers](https://docs.dagster.io/concepts/io-management/io-managers#overview) to simplify data reads and writes. I/O Managers help you reduce boilerplate code by storing Dagster Asset and Op outputs and loading them as inputs to downstream objects. They make it easy to change where and how your data is stored.
+
+You only need to define your I/O Manager and its settings (such as storage location and schema) once, and the I/O Manager will take care of correctly reading and writing all your Dagster Assets automatically.
+
+If you need lower-level access than the Dagster I/O Managers provide, take a look at the Delta Table Resource.
+
+## The Delta Lake I/O Manager
+You can easily read and write Delta Lake Tables from Dagster by using the `DeltaLakeIOManager()`.
+
+Install the DeltaLakeIOManager:
+
+```
+pip install dagster-deltalake
+```
+
+Next, configure the following settings in your project’s `__init__.py` file:
+- `io_manager`: set this to `DeltaLakeIOManager()`; this sets the default I/O Manager for all your Assets
+
+Within the DeltaLakeIOManager, define:
+- `root_uri`: the root path where your Delta Tables will be created
+- `storage_options`: configuration for accessing the storage location
+- `schema`: name of the schema to use (optional, defaults to public)
+
+```
+defs = Definitions(
+    assets=all_assets,
+    resources={
+        "io_manager": DeltaLakePyarrowIOManager(
+            root_uri="path/to/deltalake",
+            storage_options=LocalConfig(),
+            schema="dagster_deltalake",
+        ),
+    },
+)
+```
+
+Now, when you materialize an Asset, it will be saved as a Delta Lake table in the folder `dagster_deltalake/asset_name` under the root directory `path/to/deltalake`.
+
+The default Delta Lake I/O Manager supports Arrow reads and writes. You can also use the Delta Lake I/O Manager with [pandas](#using-delta-lake-and-dagster-with-pandas) or [polars](#using-delta-lake-and-dagster-with-polars).
+
+## Creating Delta Lake Tables with Dagster
+You don’t need to do anything else to store your Dagster Assets as Delta Lake tables. The I/O Manager will handle storing and loading your Assets as Delta Lake tables from now on.
+
+You can proceed to write Dagster code as you normally would. For example, you can create an Asset that reads in some toy data about animals and writes it out to an Arrow Table:
+
+```
+import pyarrow as pa
+from pyarrow import csv
+
+from dagster import asset
+
+@asset
+def raw_dataset() -> pa.Table:
+    n_legs = pa.array([2, 4, None, 100])
+    animals = pa.array(["Flamingo", "Horse", "Brittle stars", "Centipede"])
+    data = {'n_legs': n_legs, 'animals': animals}
+
+    return pa.Table.from_pydict(data)
+```
+
+When you materialize the Asset defined above (using the config settings defined earlier), the Delta Lake I/O Manager will create the table `dagster_deltalake/raw_dataset` if it doesn’t exist yet.
+
+### Overwrites when Rematerializing Assets
+If the table already exists at the specified location, the Delta Lake I/O Manager will perform an overwrite. Delta Lake’s transaction log maintains a record of all changes to your Delta Lake tables, and you can inspect this history by looking at the transaction log.
+
+## Loading Delta Lake Tables in Downstream Assets
+You can use Assets stored as Delta Lake tables as input to downstream Assets. Dagster and the Delta Lake I/O Manager make this easy for you.
+
+You can write Dagster code as you normally would. Pass the upstream Asset as an argument to the downstream object to set up the dependency. Make sure to define the correct data type.
+
+The Delta Lake I/O Manager will handle reading and writing the data from your Delta Lake.
+
+```
+import pyarrow as pa
+from dagster import asset
+
+# ... raw_dataset asset is defined here ...
+
+@asset
+def clean_dataset(raw_dataset: pa.Table) -> pa.Table:
+    return raw_dataset.drop_null()
+```
+
+## Reading Existing Delta Lake Tables into Dagster
+You can make existing Delta Lake tables (that were not created in Dagster) available to your Dagster assets.
Use the `SourceAsset` object and pass the table name as the key argument:
+
+```
+from dagster import SourceAsset
+
+existing_animal_data = SourceAsset(key="more_animal_data")
+```
+
+This will load the table `more_animal_data` from the `root_uri` and `schema` configured in the Definitions object above (see the [Delta Lake I/O Manager](#the-delta-lake-io-manager) section).
+
+## Column Pruning
+You can often improve the efficiency of your computations by only loading specific columns of your Delta table. This is called column pruning.
+
+With the Delta Lake I/O manager, you can select specific columns to load by defining the `columns` in the `metadata` parameter of the `AssetIn` that loads the upstream Asset:
+
+```
+import pyarrow as pa
+from dagster import AssetIn, asset
+
+# this example uses the clean_dataset Asset defined earlier
+
+@asset(
+    ins={
+        "mammal_bool": AssetIn(
+            key="clean_dataset",
+            metadata={"columns": ["is_mammal", "animals"]},
+        )
+    }
+)
+def mammal_data(mammal_bool: pa.Table) -> pa.Table:
+    mammals = mammal_bool["is_mammal"].cast("bool")
+    animals = mammal_bool["animals"]
+    data = {"mammal_bool": mammals, "animals": animals}
+    return pa.Table.from_pydict(data)
+```
+
+Here, we select only the `is_mammal` and `animals` columns from the `clean_dataset` table and load them into an `AssetIn` object called `mammal_bool`. This `AssetIn` object is used to create a new Asset `mammal_data`, containing only the selected columns.
+
+## Working with Partitioned Assets
+Partitioning is an important feature of Delta Lake that can make your computations more efficient. The Delta Lake I/O manager helps you read and write partitioned data easily. You can work with static partitions, time-based partitions, multi-partitions, and dynamic partitions.
+
+For example, you can partition the animal dataset on the `animals` column as follows:
+
+```
+import pyarrow as pa
+import pyarrow.compute as pc
+
+from dagster import StaticPartitionsDefinition, asset
+
+@asset(
+    partitions_def=StaticPartitionsDefinition(
+        ["Flamingo", "Horse"]
+    ),
+    metadata={"partition_expr": "animals"},
+)
+def dataset_partitioned(
+    context,
+    clean_dataset: pa.Table,
+) -> pa.Table:
+    animals = context.asset_partition_key_for_output()
+    table = clean_dataset
+
+    return table.filter(pc.field("animals") == animals)
+```
+
+To partition your data, make sure to include the relevant `partitions_def` and `metadata` arguments to the `@asset` decorator. Refer to the Dagster documentation on [partitioning assets](https://docs.dagster.io/concepts/partitions-schedules-sensors/partitioning-assets) for more information.
+
+## Using Delta Lake and Dagster with Pandas
+To read and write data to Delta Lake using pandas, use the `DeltaLakePandasIOManager()`.
+
+You will need to install it using:
+
+```
+pip install dagster-deltalake-pandas
+```
+
+In your `Definitions` object, change the `io_manager` to `DeltaLakePandasIOManager()`:
+
+```
+from dagster_deltalake_pandas import DeltaLakePandasIOManager
+
+
+defs = Definitions(
+    assets=all_assets,
+    resources={
+        "io_manager": DeltaLakePandasIOManager(
+            root_uri="path/to/deltalake",
+            storage_options=LocalConfig(),
+            schema="dagster_deltalake",
+        ),
+    },
+)
+```
+
+Now you can read and write Dagster Assets defined as pandas DataFrames in Delta Lake format.
For example:
+
+```
+import pandas as pd
+from dagster import asset
+
+@asset
+def iris_dataset() -> pd.DataFrame:
+    return pd.read_csv(
+        "https://docs.dagster.io/assets/iris.csv",
+        names=[
+            "sepal_length_cm",
+            "sepal_width_cm",
+            "petal_length_cm",
+            "petal_width_cm",
+            "species",
+        ],
+    )
+```
+
+## Using Delta Lake and Dagster with Polars
+To read and write data to Delta Lake using Polars, use the `DeltaLakePolarsIOManager()`.
+
+You will need to install it using:
+
+```
+pip install dagster-deltalake-polars
+```
+
+In your `Definitions` object, change the `io_manager` to `DeltaLakePolarsIOManager()`:
+
+```
+from dagster_deltalake_polars import DeltaLakePolarsIOManager
+
+defs = Definitions(
+    assets=all_assets,
+    resources={
+        "io_manager": DeltaLakePolarsIOManager(
+            root_uri="path/to/deltalake",
+            storage_options=LocalConfig(),
+            schema="dagster_deltalake",
+        ),
+    },
+)
+```
+
+Now you can read and write Dagster Assets defined as Polars DataFrames in Delta Lake format. For example:
+
+```
+import polars as pl
+from dagster import asset
+
+
+@asset
+def iris_dataset() -> pl.DataFrame:
+    return pl.read_csv(
+        "https://docs.dagster.io/assets/iris.csv",
+        new_columns=[
+            "sepal_length_cm",
+            "sepal_width_cm",
+            "petal_length_cm",
+            "petal_width_cm",
+            "species",
+        ],
+        has_header=False,
+    )
+```
+
+## Delta Lake Table Resource
+I/O managers are a helpful tool in many common usage situations. But when you need lower-level access, the I/O Manager might not be the right tool to use. In these cases, you may want to use the Delta Lake Table Resource.
+
+The Delta Lake Table Resource is a low-level access method to the table object. It gives you more fine-grained control and allows for modeling of more complex data. You can also use the Table Resource to run optimization and vacuuming jobs.
+
+## Schema and Constraint Enforcement
+Delta Lake provides built-in checks to ensure schema consistency when appending data to a table, as well as the ability to evolve the schema. This is a great feature for the Dagster community as it prevents bad data from being appended to tables, ensuring data consistency and accuracy.
+
+Read more about how to add constraints to a table in [the Delta Lake documentation](https://delta-io.github.io/delta-rs/usage/constraints/).
+
+## Z-Ordering
+Delta Lake offers Z-ordering functionality to colocate similar data in the same files. This can make your Delta Table queries much more efficient via file skipping. Dagster users can now benefit from this great feature through the Delta Lake I/O Manager.
+
+Read more about Z-Ordering on [the Delta Lake blog](https://delta.io/blog/2023-06-03-delta-lake-z-order/).
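+
+As a sketch of what that can look like as a maintenance job (using the `deltalake` Python package directly rather than the I/O Manager; the table path and column name below are illustrative):
+
+```
+from deltalake import DeltaTable
+
+# Z Order the materialized asset on the "animals" column so similar rows
+# end up in the same files and more files can be skipped at query time.
+dt = DeltaTable("path/to/deltalake/dagster_deltalake/clean_dataset")
+dt.optimize.z_order(["animals"])
+```
+
+The Table Resource described above is a natural place to run this kind of optimization job from within Dagster.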
+ +## Contribute +To contribute to the Delta Lake and Dagster integration, go to [link] diff --git a/docs/integrations/delta-lake-dask.md b/docs/integrations/delta-lake-dask.md index d3f0ebaf18..cb3cab27f7 100644 --- a/docs/integrations/delta-lake-dask.md +++ b/docs/integrations/delta-lake-dask.md @@ -58,7 +58,8 @@ You can read in specific versions of Delta tables by specifying a `version` numb ddf = ddt.read_deltalake(delta_path, version=3) # with specific datetime -ddt.read_deltalake(delta_path, datetime="2018-12-19T16:39:57-08:00")``` +ddt.read_deltalake(delta_path, datetime="2018-12-19T16:39:57-08:00") +``` `dask-deltatable` also supports reading from remote sources like S3 with: diff --git a/docs/integrations/delta-lake-pandas.md b/docs/integrations/delta-lake-pandas.md index b14c1bd45b..ca60362838 100644 --- a/docs/integrations/delta-lake-pandas.md +++ b/docs/integrations/delta-lake-pandas.md @@ -250,10 +250,10 @@ Schema enforcement protects your table from getting corrupted by appending data ## Overwriting schema of table -You can overwrite the table contents and schema by setting the `overwrite_schema` option. Here's how to overwrite the table contents: +You can overwrite the table contents and schema by setting the `schema_mode` option. Here's how to overwrite the table contents: ```python -write_deltalake("tmp/some-table", df, mode="overwrite", overwrite_schema=True) +write_deltalake("tmp/some-table", df, mode="overwrite", schema_mode="overwrite") ``` Here are the contents of the table after the values and schema have been overwritten: @@ -267,6 +267,8 @@ Here are the contents of the table after the values and schema have been overwri +-------+----------+ ``` +If you want the schema to be merged instead, specify schema_mode="merge". + ## In-memory vs. in-storage data changes It's important to distinguish between data stored in-memory and data stored on disk when understanding the functionality offered by Delta Lake. diff --git a/docs/src/python/check_constraints.py b/docs/src/python/check_constraints.py index 16fb8bf374..1bfa62d970 100644 --- a/docs/src/python/check_constraints.py +++ b/docs/src/python/check_constraints.py @@ -13,9 +13,11 @@ def add_constraint(): def add_data(): # --8<-- [start:add_data] - from deltalake import write_deltalake + from deltalake import write_deltalake, DeltaTable import pandas as pd + dt = DeltaTable("../rust/tests/data/simple_table") + df = pd.DataFrame({"id": [-1]}) write_deltalake(dt, df, mode="append", engine="rust") # _internal.DeltaProtocolError: Invariant violations: ["Check or Invariant (id > 0) violated by value in row: [-1]"] diff --git a/docs/usage/writing/index.md b/docs/usage/writing/index.md index 816e874993..09f3517f6f 100644 --- a/docs/usage/writing/index.md +++ b/docs/usage/writing/index.md @@ -23,7 +23,9 @@ of Spark's `pyspark.sql.DataFrameWriter.saveAsTable` DataFrame method. To overwr `write_deltalake` will raise `ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in `overwrite_schema=True`. +alter the schema as part of an overwrite pass in `schema_mode="overwrite"` or `schema_mode="merge"`. +`schema_mode="overwrite"` will completely overwrite the schema, even if columns are dropped; merge will append the new columns +and fill missing columns with `null`. `schema_mode="merge"` is also supported on append operations. 
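+
+For example, here is a minimal sketch of evolving a table's schema on append (assuming a table already exists at `tmp/some-table`; the column names are illustrative, and `schema_mode="merge"` currently requires the Rust engine):
+
+```python
+import pandas as pd
+
+from deltalake import write_deltalake
+
+# "y" is a new column that the existing table does not have yet; with
+# schema_mode="merge" it is added to the table schema and earlier rows read as null.
+df = pd.DataFrame({"x": [4, 5], "y": ["a", "b"]})
+write_deltalake("tmp/some-table", df, mode="append", schema_mode="merge", engine="rust")
+```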
## Overwriting a partition diff --git a/mkdocs.yml b/mkdocs.yml index 204069aef7..f751755009 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,6 +49,7 @@ nav: - Home: index.md - Why Use Delta Lake: why-use-delta-lake.md - Delta Lake for big and small data: delta-lake-big-data-small-data.md + - Best practices: delta-lake-best-practices.md - Usage: - Installation: usage/installation.md - Overview: usage/overview.md @@ -80,6 +81,8 @@ nav: - api/exceptions.md - Integrations: - Arrow: integrations/delta-lake-arrow.md + - Dagster: integrations/delta-lake-dagster.md + - Dask: integrations/delta-lake-dask.md - DataFusion: integrations/delta-lake-datafusion.md - pandas: integrations/delta-lake-pandas.md - Polars: integrations/delta-lake-polars.md diff --git a/python/Cargo.toml b/python/Cargo.toml index e938103921..c7fa5ca2b1 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.15.3" +version = "0.16.0" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 6319412c8e..b16d468571 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -44,7 +44,7 @@ class RawDeltaTable: def version(self) -> int: ... def get_latest_version(self) -> int: ... def metadata(self) -> RawDeltaTableMetaData: ... - def protocol_versions(self) -> List[int]: ... + def protocol_versions(self) -> List[Any]: ... def load_version(self, version: int) -> None: ... def load_with_datetime(self, ds: str) -> None: ... def files_by_partitions( @@ -84,6 +84,12 @@ class RawDeltaTable: constraints: Dict[str, str], custom_metadata: Optional[Dict[str, str]], ) -> None: ... + def drop_constraints( + self, + name: str, + raise_if_not_exists: bool, + custom_metadata: Optional[Dict[str, str]], + ) -> None: ... 
def restore( self, target: Optional[Any], @@ -168,7 +174,7 @@ def write_to_deltalake( partition_by: Optional[List[str]], mode: str, max_rows_per_group: int, - overwrite_schema: bool, + schema_mode: Optional[str], predicate: Optional[str], name: Optional[str], description: Optional[str], @@ -226,6 +232,7 @@ class PrimitiveType: * "binary", * "date", * "timestamp", + * "timestampNtz", * "decimal(, )" Args: @@ -789,6 +796,11 @@ class DeltaProtocolError(DeltaError): pass +class SchemaMismatchError(DeltaError): + """Raised when a schema mismatch is detected.""" + + pass + FilterLiteralType = Tuple[str, str, Any] FilterConjunctionType = List[FilterLiteralType] FilterDNFType = List[FilterConjunctionType] diff --git a/python/deltalake/exceptions.py b/python/deltalake/exceptions.py index bacd0af9f8..a2e5b1ba1e 100644 --- a/python/deltalake/exceptions.py +++ b/python/deltalake/exceptions.py @@ -1,4 +1,5 @@ from ._internal import CommitFailedError as CommitFailedError from ._internal import DeltaError as DeltaError from ._internal import DeltaProtocolError as DeltaProtocolError +from ._internal import SchemaMismatchError as SchemaMismatchError from ._internal import TableNotFoundError as TableNotFoundError diff --git a/python/deltalake/schema.py b/python/deltalake/schema.py index 2f89cfaa2d..7124d51de3 100644 --- a/python/deltalake/schema.py +++ b/python/deltalake/schema.py @@ -51,9 +51,10 @@ def dtype_to_delta_dtype(dtype: pa.DataType) -> pa.DataType: elif isinstance(dtype, pa.StructType): return struct_to_delta_dtype(dtype) elif isinstance(dtype, pa.TimestampType): - return pa.timestamp( - "us" - ) # TODO(ion): propagate also timezone information during writeonce we can properly read TZ in delta schema + if dtype.tz is None: + return pa.timestamp("us") + else: + return pa.timestamp("us", "UTC") elif type(dtype) is pa.FixedSizeBinaryType: return pa.binary() try: diff --git a/python/deltalake/table.py b/python/deltalake/table.py index b6dd27f49d..5869ceb2e2 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -53,8 +53,13 @@ else: _has_pandas = True -MAX_SUPPORTED_READER_VERSION = 1 -MAX_SUPPORTED_WRITER_VERSION = 2 +MAX_SUPPORTED_PYARROW_WRITER_VERSION = 7 +NOT_SUPPORTED_PYARROW_WRITER_VERSIONS = [3, 4, 5, 6] +SUPPORTED_WRITER_FEATURES = {"appendOnly", "invariants", "timestampNtz"} + +MAX_SUPPORTED_READER_VERSION = 3 +NOT_SUPPORTED_READER_VERSION = 2 +SUPPORTED_READER_FEATURES = {"timestampNtz"} class Compression(Enum): @@ -242,6 +247,8 @@ def __str__(self) -> str: class ProtocolVersions(NamedTuple): min_reader_version: int min_writer_version: int + writer_features: Optional[List[str]] + reader_features: Optional[List[str]] FilterLiteralType = Tuple[str, str, Any] @@ -1025,11 +1032,26 @@ def to_pyarrow_dataset( Returns: the PyArrow dataset in PyArrow """ - if self.protocol().min_reader_version > MAX_SUPPORTED_READER_VERSION: + table_protocol = self.protocol() + if ( + table_protocol.min_reader_version > MAX_SUPPORTED_READER_VERSION + or table_protocol.min_reader_version == NOT_SUPPORTED_READER_VERSION + ): raise DeltaProtocolError( - f"The table's minimum reader version is {self.protocol().min_reader_version} " - f"but deltalake only supports up to version {MAX_SUPPORTED_READER_VERSION}." 
+ f"The table's minimum reader version is {table_protocol.min_reader_version} " + f"but deltalake only supports version 1 or {MAX_SUPPORTED_READER_VERSION} with these reader features: {SUPPORTED_READER_FEATURES}" + ) + if ( + table_protocol.min_reader_version >= 3 + and table_protocol.reader_features is not None + ): + missing_features = {*table_protocol.reader_features}.difference( + SUPPORTED_READER_FEATURES ) + if len(missing_features) > 0: + raise DeltaProtocolError( + f"The table has set these reader features: {missing_features} but these are not yet supported by the deltalake reader." + ) if not filesystem: file_sizes = self.get_add_actions().to_pydict() @@ -1277,9 +1299,9 @@ def __init__( self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_by_source_update_predicate: Optional[ - List[Optional[str]] - ] = None + self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = ( + None + ) self.not_matched_by_source_delete_predicate: Optional[List[str]] = None self.not_matched_by_source_delete_all: Optional[bool] = None @@ -1775,6 +1797,40 @@ def add_constraint( self.table._table.add_constraints(constraints, custom_metadata) + def drop_constraint( + self, + name: str, + raise_if_not_exists: bool = True, + custom_metadata: Optional[Dict[str, str]] = None, + ) -> None: + """ + Drop constraints from a table. Limited to `single constraint` at once. + + Args: + name: constraint name which to drop. + raise_if_not_exists: set if should raise if not exists. + custom_metadata: custom metadata that will be added to the transaction commit. + Example: + ```python + from deltalake import DeltaTable + dt = DeltaTable("test_table_constraints") + dt.metadata().configuration + {'delta.constraints.value_gt_5': 'value > 5'} + ``` + + **Drop the constraint** + ```python + dt.alter.drop_constraint(name = "value_gt_5") + ``` + + **Configuration after dropping** + ```python + dt.metadata().configuration + {} + ``` + """ + self.table._table.drop_constraints(name, raise_if_not_exists, custom_metadata) + class TableOptimizer: """API for various table optimization commands.""" diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 20eebc6fd5..4a61254a23 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -31,6 +31,8 @@ else: from typing_extensions import Literal +import warnings + import pyarrow as pa import pyarrow.dataset as ds import pyarrow.fs as pa_fs @@ -48,7 +50,13 @@ convert_pyarrow_recordbatchreader, convert_pyarrow_table, ) -from .table import MAX_SUPPORTED_WRITER_VERSION, DeltaTable, WriterProperties +from .table import ( + MAX_SUPPORTED_PYARROW_WRITER_VERSION, + NOT_SUPPORTED_PYARROW_WRITER_VERSIONS, + SUPPORTED_WRITER_FEATURES, + DeltaTable, + WriterProperties, +) try: import pandas as pd # noqa: F811 @@ -95,13 +103,13 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., partition_filters: Optional[List[Tuple[str, str, Any]]] = ..., large_dtypes: bool = ..., engine: Literal["pyarrow"] = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... 
@overload @@ -123,13 +131,13 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., large_dtypes: bool = ..., engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... @overload @@ -151,14 +159,14 @@ def write_deltalake( description: Optional[str] = ..., configuration: Optional[Mapping[str, Optional[str]]] = ..., overwrite_schema: bool = ..., + schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., predicate: Optional[str] = ..., large_dtypes: bool = ..., engine: Literal["rust"], writer_properties: WriterProperties = ..., custom_metadata: Optional[Dict[str, str]] = ..., -) -> None: - ... +) -> None: ... def write_deltalake( @@ -185,6 +193,7 @@ def write_deltalake( description: Optional[str] = None, configuration: Optional[Mapping[str, Optional[str]]] = None, overwrite_schema: bool = False, + schema_mode: Optional[Literal["merge", "overwrite"]] = None, storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, @@ -238,7 +247,8 @@ def write_deltalake( name: User-provided identifier for this table. description: User-provided description for this table. configuration: A map containing configuration options for the metadata action. - overwrite_schema: If True, allows updating the schema of the table. + overwrite_schema: Deprecated, use schema_mode instead. + schema_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. storage_options: options passed to the native delta filesystem. predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. @@ -256,12 +266,19 @@ def write_deltalake( table.update_incremental() __enforce_append_only(table=table, configuration=configuration, mode=mode) + if overwrite_schema: + schema_mode = "overwrite" + warnings.warn( + "overwrite_schema is deprecated, use schema_mode instead. ", + category=DeprecationWarning, + stacklevel=2, + ) if isinstance(partition_by, str): partition_by = [partition_by] if isinstance(schema, DeltaSchema): - schema = schema.to_pyarrow() + schema = schema.to_pyarrow(as_large_types=True) if isinstance(data, RecordBatchReader): data = convert_pyarrow_recordbatchreader(data, large_dtypes) @@ -302,7 +319,7 @@ def write_deltalake( partition_by=partition_by, mode=mode, max_rows_per_group=max_rows_per_group, - overwrite_schema=overwrite_schema, + schema_mode=schema_mode, predicate=predicate, name=name, description=description, @@ -317,13 +334,21 @@ def write_deltalake( table.update_incremental() elif engine == "pyarrow": + if schema_mode == "merge": + raise ValueError( + "schema_mode 'merge' is not supported in pyarrow engine. 
Use engine=rust" + ) # We need to write against the latest table version filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) + def sort_arrow_schema(schema: pa.schema) -> pa.schema: + sorted_cols = sorted(iter(schema), key=lambda x: (x.name, str(x.type))) + return pa.schema(sorted_cols) + if table: # already exists - if schema != table.schema().to_pyarrow( - as_large_types=large_dtypes - ) and not (mode == "overwrite" and overwrite_schema): + if sort_arrow_schema(schema) != sort_arrow_schema( + table.schema().to_pyarrow(as_large_types=large_dtypes) + ) and not (mode == "overwrite" and schema_mode == "overwrite"): raise ValueError( "Schema of data does not match table schema\n" f"Data schema:\n{schema}\nTable Schema:\n{table.schema().to_pyarrow(as_large_types=large_dtypes)}" @@ -402,12 +427,29 @@ def visitor(written_file: Any) -> None: if table is not None: # We don't currently provide a way to set invariants # (and maybe never will), so only enforce if already exist. - if table.protocol().min_writer_version > MAX_SUPPORTED_WRITER_VERSION: + table_protocol = table.protocol() + if ( + table_protocol.min_writer_version > MAX_SUPPORTED_PYARROW_WRITER_VERSION + or table_protocol.min_writer_version + in NOT_SUPPORTED_PYARROW_WRITER_VERSIONS + ): raise DeltaProtocolError( "This table's min_writer_version is " - f"{table.protocol().min_writer_version}, " - "but this method only supports version 2." + f"{table_protocol.min_writer_version}, " + f"""but this method only supports version 2 or 7 with at max these features {SUPPORTED_WRITER_FEATURES} enabled. + Try engine='rust' instead which supports more features and writer versions.""" + ) + if ( + table_protocol.min_writer_version >= 7 + and table_protocol.writer_features is not None + ): + missing_features = {*table_protocol.writer_features}.difference( + SUPPORTED_WRITER_FEATURES ) + if len(missing_features) > 0: + raise DeltaProtocolError( + f"The table has set these writer features: {missing_features} but these are not supported by the pyarrow writer. Please use engine='rust'." 
+ ) invariants = table.schema().invariants checker = _DeltaDataChecker(invariants) @@ -417,12 +459,12 @@ def check_data_is_aligned_with_partition_filtering( ) -> None: if table is None: return - existed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions() - allowed_partitions: FrozenSet[ - FrozenSet[Tuple[str, Optional[str]]] - ] = table._table.get_active_partitions(partition_filters) + existed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions() + ) + allowed_partitions: FrozenSet[FrozenSet[Tuple[str, Optional[str]]]] = ( + table._table.get_active_partitions(partition_filters) + ) partition_values = pa.RecordBatch.from_arrays( [ batch.column(column_name) diff --git a/python/docs/source/_ext/edit_on_github.py b/python/docs/source/_ext/edit_on_github.py index f7188f189a..241560877c 100644 --- a/python/docs/source/_ext/edit_on_github.py +++ b/python/docs/source/_ext/edit_on_github.py @@ -38,9 +38,9 @@ def html_page_context(app, pagename, templatename, context, doctree): context["display_github"] = True context["github_user"] = app.config.edit_on_github_project.split("/")[0] context["github_repo"] = app.config.edit_on_github_project.split("/")[1] - context[ - "github_version" - ] = f"{app.config.edit_on_github_branch}/{app.config.page_source_prefix}/" + context["github_version"] = ( + f"{app.config.edit_on_github_branch}/{app.config.page_source_prefix}/" + ) def setup(app): diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index 1b6f835152..e5fbcc2e9f 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -481,7 +481,7 @@ to append pass in ``mode='append'``: :py:meth:`write_deltalake` will raise :py:exc:`ValueError` if the schema of the data passed to it differs from the existing table's schema. If you wish to -alter the schema as part of an overwrite pass in ``overwrite_schema=True``. +alter the schema as part of an overwrite pass in ``schema_mode="overwrite"``. .. 
note:: if the local file system does not support hard links, such as blobfuse, diff --git a/python/pyproject.toml b/python/pyproject.toml index e9fc7389af..9b74760948 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,8 +27,8 @@ pandas = [ "pandas" ] devel = [ - "mypy", - "ruff>=0.1.5", + "mypy~=1.8.0", + "ruff~=0.3.0", "packaging>=20", "pytest", "pytest-mock", diff --git a/python/src/error.rs b/python/src/error.rs index a69160e3ec..a54b1e60b4 100644 --- a/python/src/error.rs +++ b/python/src/error.rs @@ -10,6 +10,7 @@ create_exception!(_internal, DeltaError, PyException); create_exception!(_internal, TableNotFoundError, DeltaError); create_exception!(_internal, DeltaProtocolError, DeltaError); create_exception!(_internal, CommitFailedError, DeltaError); +create_exception!(_internal, SchemaMismatchError, DeltaError); fn inner_to_py_err(err: DeltaTableError) -> PyErr { match err { @@ -55,6 +56,7 @@ fn arrow_to_py(err: ArrowError) -> PyErr { ArrowError::DivideByZero => PyValueError::new_err("division by zero"), ArrowError::InvalidArgumentError(msg) => PyValueError::new_err(msg), ArrowError::NotYetImplemented(msg) => PyNotImplementedError::new_err(msg), + ArrowError::SchemaError(msg) => SchemaMismatchError::new_err(msg), other => PyException::new_err(other.to_string()), } } diff --git a/python/src/lib.rs b/python/src/lib.rs index 53efe489f2..0d98d6906d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -29,6 +29,7 @@ use deltalake::kernel::{Action, Add, Invariant, LogicalFile, Remove, Scalar, Str use deltalake::operations::constraints::ConstraintBuilder; use deltalake::operations::convert_to_delta::{ConvertToDeltaBuilder, PartitionStrategy}; use deltalake::operations::delete::DeleteBuilder; +use deltalake::operations::drop_constraints::DropConstraintBuilder; use deltalake::operations::filesystem_check::FileSystemCheckBuilder; use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; @@ -87,6 +88,8 @@ struct RawDeltaTableMetaData { configuration: HashMap>, } +type StringVec = Vec; + #[pymethods] impl RawDeltaTable { #[new] @@ -145,16 +148,35 @@ impl RawDeltaTable { }) } - pub fn protocol_versions(&self) -> PyResult<(i32, i32)> { + pub fn protocol_versions(&self) -> PyResult<(i32, i32, Option, Option)> { + let table_protocol = self._table.protocol().map_err(PythonError::from)?; Ok(( - self._table - .protocol() - .map_err(PythonError::from)? - .min_reader_version, - self._table - .protocol() - .map_err(PythonError::from)? 
- .min_writer_version, + table_protocol.min_reader_version, + table_protocol.min_writer_version, + table_protocol + .writer_features + .as_ref() + .and_then(|features| { + let empty_set = !features.is_empty(); + empty_set.then(|| { + features + .iter() + .map(|v| v.to_string()) + .collect::>() + }) + }), + table_protocol + .reader_features + .as_ref() + .and_then(|features| { + let empty_set = !features.is_empty(); + empty_set.then(|| { + features + .iter() + .map(|v| v.to_string()) + .collect::>() + }) + }), )) } @@ -470,6 +492,33 @@ impl RawDeltaTable { Ok(()) } + #[pyo3(signature = (name, raise_if_not_exists, custom_metadata=None))] + pub fn drop_constraints( + &mut self, + name: String, + raise_if_not_exists: bool, + custom_metadata: Option>, + ) -> PyResult<()> { + let mut cmd = DropConstraintBuilder::new( + self._table.log_store(), + self._table.snapshot().map_err(PythonError::from)?.clone(), + ) + .with_constraint(name) + .with_raise_if_not_exists(raise_if_not_exists); + + if let Some(metadata) = custom_metadata { + let json_metadata: Map = + metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); + cmd = cmd.with_metadata(json_metadata); + }; + + let table = rt()? + .block_on(cmd.into_future()) + .map_err(PythonError::from)?; + self._table.state = table.state; + Ok(()) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (source, predicate, @@ -1114,10 +1163,13 @@ fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult val.to_object(py), Float(val) => val.to_object(py), Double(val) => val.to_object(py), - // TODO: Since PyArrow 13.0.0, casting string -> timestamp fails if it ends with "Z" - // and the target type is timezone naive. The serialization does not produce "Z", - // but we need to consider timezones when doing timezone ntz. Timestamp(_) => { + // We need to manually append 'Z' add to end so that pyarrow can cast the + // the scalar value to pa.timestamp("us","UTC") + let value = value.serialize(); + format!("{}Z", value).to_object(py) + } + TimestampNtz(_) => { let value = value.serialize(); value.to_object(py) } @@ -1335,11 +1387,12 @@ impl From<&PyAddAction> for Add { #[pyfunction] #[allow(clippy::too_many_arguments)] fn write_to_deltalake( + py: Python, table_uri: String, data: PyArrowType, mode: String, max_rows_per_group: i64, - overwrite_schema: bool, + schema_mode: Option, partition_by: Option>, predicate: Option, name: Option, @@ -1349,59 +1402,62 @@ fn write_to_deltalake( writer_properties: Option>>, custom_metadata: Option>, ) -> PyResult<()> { - let batches = data.0.map(|batch| batch.unwrap()).collect::>(); - let save_mode = mode.parse().map_err(PythonError::from)?; - - let options = storage_options.clone().unwrap_or_default(); - let table = rt()? - .block_on(DeltaOps::try_from_uri_with_storage_options( - &table_uri, options, - )) - .map_err(PythonError::from)?; + py.allow_threads(|| { + let batches = data.0.map(|batch| batch.unwrap()).collect::>(); + let save_mode = mode.parse().map_err(PythonError::from)?; - let mut builder = table - .write(batches) - .with_save_mode(save_mode) - .with_overwrite_schema(overwrite_schema) - .with_write_batch_size(max_rows_per_group as usize); + let options = storage_options.clone().unwrap_or_default(); + let table = rt()? 
+ .block_on(DeltaOps::try_from_uri_with_storage_options( + &table_uri, options, + )) + .map_err(PythonError::from)?; - if let Some(partition_columns) = partition_by { - builder = builder.with_partition_columns(partition_columns); - } + let mut builder = table + .write(batches) + .with_save_mode(save_mode) + .with_write_batch_size(max_rows_per_group as usize); + if let Some(schema_mode) = schema_mode { + builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?); + } + if let Some(partition_columns) = partition_by { + builder = builder.with_partition_columns(partition_columns); + } - if let Some(writer_props) = writer_properties { - builder = builder.with_writer_properties( - set_writer_properties(writer_props).map_err(PythonError::from)?, - ); - } + if let Some(writer_props) = writer_properties { + builder = builder.with_writer_properties( + set_writer_properties(writer_props).map_err(PythonError::from)?, + ); + } - if let Some(name) = &name { - builder = builder.with_table_name(name); - }; + if let Some(name) = &name { + builder = builder.with_table_name(name); + }; - if let Some(description) = &description { - builder = builder.with_description(description); - }; + if let Some(description) = &description { + builder = builder.with_description(description); + }; - if let Some(predicate) = predicate { - builder = builder.with_replace_where(predicate); - }; + if let Some(predicate) = predicate { + builder = builder.with_replace_where(predicate); + }; - if let Some(config) = configuration { - builder = builder.with_configuration(config); - }; + if let Some(config) = configuration { + builder = builder.with_configuration(config); + }; - if let Some(metadata) = custom_metadata { - let json_metadata: Map = - metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); - builder = builder.with_metadata(json_metadata); - }; + if let Some(metadata) = custom_metadata { + let json_metadata: Map = + metadata.into_iter().map(|(k, v)| (k, v.into())).collect(); + builder = builder.with_metadata(json_metadata); + }; - rt()? - .block_on(builder.into_future()) - .map_err(PythonError::from)?; + rt()? 
+ .block_on(builder.into_future()) + .map_err(PythonError::from)?; - Ok(()) + Ok(()) + }) } #[pyfunction] @@ -1595,7 +1651,7 @@ impl PyDeltaDataChecker { #[pymodule] // module name need to match project name fn _internal(py: Python, m: &PyModule) -> PyResult<()> { - use crate::error::{CommitFailedError, DeltaError, TableNotFoundError}; + use crate::error::{CommitFailedError, DeltaError, SchemaMismatchError, TableNotFoundError}; deltalake::aws::register_handlers(None); deltalake::azure::register_handlers(None); @@ -1605,6 +1661,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add("CommitFailedError", py.get_type::())?; m.add("DeltaProtocolError", py.get_type::())?; m.add("TableNotFoundError", py.get_type::())?; + m.add("SchemaMismatchError", py.get_type::())?; env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("warn")).init(); m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/python/tests/data_acceptance/test_reader.py b/python/tests/data_acceptance/test_reader.py index 8d3faff49e..ae77100618 100644 --- a/python/tests/data_acceptance/test_reader.py +++ b/python/tests/data_acceptance/test_reader.py @@ -46,6 +46,7 @@ class ReadCase(NamedTuple): failing_cases = { "multi_partitioned_2": "Waiting for PyArrow 11.0.0 for decimal cast support (#1078)", "multi_partitioned": "Test case handles binary poorly", + "all_primitive_types": "The parquet table written with PySpark incorrectly wrote a timestamp primitive without Timezone information.", } diff --git a/python/tests/pyspark_integration/test_write_to_pyspark.py b/python/tests/pyspark_integration/test_write_to_pyspark.py index 8418f587ca..5cf6490a62 100644 --- a/python/tests/pyspark_integration/test_write_to_pyspark.py +++ b/python/tests/pyspark_integration/test_write_to_pyspark.py @@ -1,4 +1,5 @@ """Tests that deltalake(delta-rs) can write to tables written by PySpark.""" + import pathlib import pyarrow as pa diff --git a/python/tests/pyspark_integration/test_writer_readable.py b/python/tests/pyspark_integration/test_writer_readable.py index ea555074b8..3ade57c6e9 100644 --- a/python/tests/pyspark_integration/test_writer_readable.py +++ b/python/tests/pyspark_integration/test_writer_readable.py @@ -1,4 +1,5 @@ """Test that pyspark can read tables written by deltalake(delta-rs).""" + import pathlib import pyarrow as pa diff --git a/python/tests/test_alter.py b/python/tests/test_alter.py index 7069124581..4bc902d330 100644 --- a/python/tests/test_alter.py +++ b/python/tests/test_alter.py @@ -20,6 +20,7 @@ def test_add_constraint(tmp_path: pathlib.Path, sample_table: pa.Table): assert dt.metadata().configuration == { "delta.constraints.check_price": "price >= 0" } + assert dt.protocol().min_writer_version == 3 with pytest.raises(DeltaError): # Invalid constraint @@ -60,3 +61,55 @@ def test_add_constraint_roundtrip_metadata( ) assert dt.history(1)[0]["userName"] == "John Doe" + + +def test_drop_constraint(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table) + + dt = DeltaTable(tmp_path) + + dt.alter.add_constraint({"check_price": "price >= 0"}) + assert dt.protocol().min_writer_version == 3 + dt.alter.drop_constraint(name="check_price") + last_action = dt.history(1)[0] + assert last_action["operation"] == "DROP CONSTRAINT" + assert dt.version() == 2 + assert dt.metadata().configuration == {} + assert dt.protocol().min_writer_version == 3 + + +def test_drop_constraint_invalid(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, 
sample_table) + + dt = DeltaTable(tmp_path) + + dt.alter.add_constraint({"check_price": "price >= 0"}) + with pytest.raises(DeltaError): + dt.alter.drop_constraint(name="invalid_constraint_name") + + assert dt.metadata().configuration == { + "delta.constraints.check_price": "price >= 0" + } + assert dt.protocol().min_writer_version == 3 + + +def test_drop_constraint_invalid_ignore(tmp_path: pathlib.Path, sample_table: pa.Table): + write_deltalake(tmp_path, sample_table) + + dt = DeltaTable(tmp_path) + + dt.alter.add_constraint({"check_price": "price >= 0"}) + dt.alter.drop_constraint(name="invalid_constraint_name", raise_if_not_exists=False) + + +def test_drop_constraint_roundtrip_metadata( + tmp_path: pathlib.Path, sample_table: pa.Table +): + write_deltalake(tmp_path, sample_table, mode="append", engine="rust") + + dt = DeltaTable(tmp_path) + + dt.alter.add_constraint({"check_price2": "price >= 0"}) + dt.alter.drop_constraint("check_price2", custom_metadata={"userName": "John Doe"}) + + assert dt.history(1)[0]["userName"] == "John Doe" diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 7feb66a743..4d70c720dd 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -404,10 +404,10 @@ def test_delta_schema(): ("timestamp1", pa.timestamp("us")), ("timestamp2", pa.timestamp("us")), ("timestamp3", pa.timestamp("us")), - ("timestamp4", pa.timestamp("us")), - ("timestamp5", pa.timestamp("us")), - ("timestamp6", pa.timestamp("us")), - ("timestamp7", pa.timestamp("us")), + ("timestamp4", pa.timestamp("us", tz="UTC")), + ("timestamp5", pa.timestamp("us", tz="UTC")), + ("timestamp6", pa.timestamp("us", tz="UTC")), + ("timestamp7", pa.timestamp("us", tz="UTC")), ] ), False, diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py index ce46a41a2a..6372b10e2f 100644 --- a/python/tests/test_table_read.py +++ b/python/tests/test_table_read.py @@ -1,5 +1,5 @@ import os -from datetime import date, datetime +from datetime import date, datetime, timezone from pathlib import Path from threading import Barrier, Thread from types import SimpleNamespace @@ -36,12 +36,20 @@ def test_read_table_with_edge_timestamps(): parquet_read_options=ParquetReadOptions(coerce_int96_timestamp_unit="ms") ) assert dataset.to_table().to_pydict() == { - "BIG_DATE": [datetime(9999, 12, 31, 0, 0, 0), datetime(9999, 12, 30, 0, 0, 0)], - "NORMAL_DATE": [datetime(2022, 1, 1, 0, 0, 0), datetime(2022, 2, 1, 0, 0, 0)], + "BIG_DATE": [ + datetime(9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc), + datetime(9999, 12, 30, 0, 0, 0, tzinfo=timezone.utc), + ], + "NORMAL_DATE": [ + datetime(2022, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + datetime(2022, 2, 1, 0, 0, 0, tzinfo=timezone.utc), + ], "SOME_VALUE": [1, 2], } # Can push down filters to these timestamps. 
- predicate = ds.field("BIG_DATE") == datetime(9999, 12, 31, 0, 0, 0) + predicate = ds.field("BIG_DATE") == datetime( + 9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc + ) assert len(list(dataset.get_fragments(predicate))) == 1 @@ -491,7 +499,7 @@ def test_delta_table_with_filters(): def test_writer_fails_on_protocol(): table_path = "../crates/test/tests/data/simple_table" dt = DeltaTable(table_path) - dt.protocol = Mock(return_value=ProtocolVersions(2, 1)) + dt.protocol = Mock(return_value=ProtocolVersions(2, 1, None, None)) with pytest.raises(DeltaProtocolError): dt.to_pyarrow_dataset() with pytest.raises(DeltaProtocolError): diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index 6d6065c327..adb4023d86 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -17,7 +17,12 @@ from pyarrow.lib import RecordBatchReader from deltalake import DeltaTable, Schema, write_deltalake -from deltalake.exceptions import CommitFailedError, DeltaError, DeltaProtocolError +from deltalake.exceptions import ( + CommitFailedError, + DeltaError, + DeltaProtocolError, + SchemaMismatchError, +) from deltalake.table import ProtocolVersions from deltalake.writer import try_get_table_and_table_uri @@ -124,11 +129,17 @@ def test_enforce_schema(existing_table: DeltaTable, mode: str): def test_enforce_schema_rust_writer(existing_table: DeltaTable, mode: str): bad_data = pa.table({"x": pa.array([1, 2, 3])}) - with pytest.raises(DeltaError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): write_deltalake(existing_table, bad_data, mode=mode, engine="rust") table_uri = existing_table._table.table_uri() - with pytest.raises(DeltaError): + with pytest.raises( + SchemaMismatchError, + match=".*Cannot cast schema, number of fields does not match.*", + ): write_deltalake(table_uri, bad_data, mode=mode, engine="rust") @@ -136,48 +147,154 @@ def test_update_schema(existing_table: DeltaTable): new_data = pa.table({"x": pa.array([1, 2, 3])}) with pytest.raises(ValueError): - write_deltalake(existing_table, new_data, mode="append", overwrite_schema=True) + write_deltalake( + existing_table, new_data, mode="append", schema_mode="overwrite" + ) - write_deltalake(existing_table, new_data, mode="overwrite", overwrite_schema=True) + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") read_data = existing_table.to_pyarrow_table() assert new_data == read_data assert existing_table.schema().to_pyarrow() == new_data.schema -def test_update_schema_rust_writer(existing_table: DeltaTable): - new_data = pa.table({"x": pa.array([1, 2, 3])}) +def test_merge_schema(existing_table: DeltaTable): + print(existing_table._table.table_uri()) + old_table_data = existing_table.to_pyarrow_table() + new_data = pa.table( + { + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) + + write_deltalake( + existing_table, new_data, mode="append", schema_mode="merge", engine="rust" + ) + # adjust schema of old_table_data and new_data to match each other + + for i in range(old_table_data.num_columns): + col = old_table_data.schema.field(i) + new_data = new_data.add_column(i, col, pa.nulls(new_data.num_rows, col.type)) + + old_table_data = old_table_data.append_column( + pa.field("new_x", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + old_table_data = old_table_data.append_column( + pa.field("new_y", pa.int32()), pa.nulls(old_table_data.num_rows, pa.int32()) + ) + + # 
define sort order + read_data = existing_table.to_pyarrow_table().sort_by( + [("utf8", "ascending"), ("new_x", "ascending")] + ) + print(repr(read_data.to_pylist())) + concated = pa.concat_tables([old_table_data, new_data]) + print(repr(concated.to_pylist())) + assert read_data == concated + + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") + + assert existing_table.schema().to_pyarrow() == new_data.schema + + +def test_overwrite_schema(existing_table: DeltaTable): + new_data_invalid = pa.table( + { + "utf8": pa.array([1235, 546, 5645]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) with pytest.raises(DeltaError): write_deltalake( existing_table, - new_data, + new_data_invalid, mode="append", - overwrite_schema=True, + schema_mode="overwrite", engine="rust", ) + + new_data = pa.table( + { + "utf8": pa.array(["bla", "bli", "blubb"]), + "new_x": pa.array([1, 2, 3], pa.int32()), + "new_y": pa.array([1, 2, 3], pa.int32()), + } + ) with pytest.raises(DeltaError): write_deltalake( existing_table, new_data, - mode="overwrite", - overwrite_schema=False, + mode="append", + schema_mode="overwrite", + engine="rust", + ) + + write_deltalake(existing_table, new_data, mode="overwrite", schema_mode="overwrite") + + assert existing_table.schema().to_pyarrow() == new_data.schema + + +def test_update_schema_rust_writer_append(existing_table: DeltaTable): + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): + # It's illegal to do schema drift without correct schema_mode + write_deltalake( + existing_table, + pa.table({"x4": pa.array([1, 2, 3])}), + mode="append", + schema_mode=None, engine="rust", ) with pytest.raises(DeltaError): + write_deltalake( # schema_mode overwrite is illegal with append + existing_table, + pa.table({"x1": pa.array([1, 2, 3])}), + mode="append", + schema_mode="overwrite", + engine="rust", + ) + with pytest.raises( + SchemaMismatchError, + match="Schema error: Fail to merge schema field 'utf8' because the from data_type = Int64 does not equal Utf8", + ): write_deltalake( existing_table, - new_data, + pa.table({"utf8": pa.array([1, 2, 3])}), mode="append", - overwrite_schema=False, + schema_mode="merge", engine="rust", ) - # TODO(ion): Remove this once we add schema overwrite support + write_deltalake( + existing_table, + pa.table({"x2": pa.array([1, 2, 3])}), + mode="append", + schema_mode="merge", + engine="rust", + ) + + +def test_update_schema_rust_writer_invalid(existing_table: DeltaTable): + new_data = pa.table({"x5": pa.array([1, 2, 3])}) + with pytest.raises( + SchemaMismatchError, match="Cannot cast schema, number of fields does not match" + ): + write_deltalake( + existing_table, + new_data, + mode="overwrite", + schema_mode=None, + engine="rust", + ) + write_deltalake( existing_table, new_data, mode="overwrite", - overwrite_schema=True, + schema_mode="overwrite", engine="rust", ) @@ -611,7 +728,7 @@ def test_writer_fails_on_protocol( sample_data: pa.Table, engine: Literal["pyarrow", "rust"], ): - existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3)) + existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3, None, None)) with pytest.raises(DeltaProtocolError): write_deltalake(existing_table, sample_data, mode="overwrite", engine=engine) @@ -685,35 +802,58 @@ def test_writer_with_options(tmp_path: pathlib.Path): def test_try_get_table_and_table_uri(tmp_path: pathlib.Path): + def _normalize_path(t): # who does 
not love Windows? ;) + return t[0], t[1].replace("\\", "/") if t[1] else t[1] + data = pa.table({"vals": pa.array(["1", "2", "3"])}) table_or_uri = tmp_path / "delta_table" write_deltalake(table_or_uri, data) delta_table = DeltaTable(table_or_uri) # table_or_uri as DeltaTable - assert try_get_table_and_table_uri(delta_table, None) == ( - delta_table, - str(tmp_path / "delta_table") + "/", + assert _normalize_path( + try_get_table_and_table_uri(delta_table, None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table") + "/", + ) ) # table_or_uri as str - assert try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) == ( - delta_table, - str(tmp_path / "delta_table"), + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "delta_table"), None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) ) - assert try_get_table_and_table_uri(str(tmp_path / "str"), None) == ( - None, - str(tmp_path / "str"), + assert _normalize_path( + try_get_table_and_table_uri(str(tmp_path / "str"), None) + ) == _normalize_path( + ( + None, + str(tmp_path / "str"), + ) ) # table_or_uri as Path - assert try_get_table_and_table_uri(tmp_path / "delta_table", None) == ( - delta_table, - str(tmp_path / "delta_table"), + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "delta_table", None) + ) == _normalize_path( + ( + delta_table, + str(tmp_path / "delta_table"), + ) ) - assert try_get_table_and_table_uri(tmp_path / "Path", None) == ( - None, - str(tmp_path / "Path"), + assert _normalize_path( + try_get_table_and_table_uri(tmp_path / "Path", None) + ) == _normalize_path( + ( + None, + str(tmp_path / "Path"), + ) ) # table_or_uri with invalid parameter type @@ -1297,3 +1437,43 @@ def test_write_stats_empty_rowgroups(tmp_path: pathlib.Path): dt.to_pyarrow_dataset().to_table(filter=(pc.field("data") == "B")).shape[0] == 33792 ) + + +@pytest.mark.parametrize("engine", ["pyarrow", "rust"]) +def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine): + data = pa.table( + { + "foo": pa.array(["B"] * 10), + "bar": pa.array([1] * 10), + "baz": pa.array([2.0] * 10), + } + ) + write_deltalake(tmp_path, data, mode="append", engine=engine) + + data = pa.table( + { + "baz": pa.array([2.0] * 10), + "bar": pa.array([1] * 10), + "foo": pa.array(["B"] * 10), + } + ) + write_deltalake(tmp_path, data, mode="append", engine=engine) + dt = DeltaTable(tmp_path) + assert dt.version() == 1 + + expected = pa.table( + { + "baz": pa.array([2.0] * 20), + "bar": pa.array([1] * 20), + "foo": pa.array(["B"] * 20), + } + ) + + assert dt.to_pyarrow_table(columns=["baz", "bar", "foo"]) == expected + + +def test_empty(existing_table: DeltaTable): + schema = existing_table.schema().to_pyarrow() + empty_table = pa.Table.from_pylist([], schema=schema) + with pytest.raises(DeltaError, match="No data source supplied to write command"): + write_deltalake(existing_table, empty_table, mode="append", engine="rust")
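To tie the writer-side changes together, the following is a minimal, hedged sketch of how the new `schema_mode` parameter and `SchemaMismatchError` behave, modelled on the new tests in `python/tests/test_writer.py`; the table location and column names are placeholders chosen for illustration.

```python
# Sketch only: path and columns are hypothetical, behaviour mirrors the new tests.
import pyarrow as pa

from deltalake import write_deltalake
from deltalake.exceptions import DeltaError, SchemaMismatchError

path = "/tmp/schema_mode_demo"  # hypothetical location
write_deltalake(
    path, pa.table({"x": pa.array([1, 2, 3]), "y": pa.array(["a", "b", "c"])})
)

drifted = pa.table({"z": pa.array([4, 5, 6])})

# Without a schema_mode, appending data whose schema differs from the table
# now surfaces SchemaMismatchError instead of a generic DeltaError.
try:
    write_deltalake(path, drifted, mode="append", engine="rust")
except SchemaMismatchError:
    pass

# schema_mode="overwrite" is rejected when combined with mode="append".
try:
    write_deltalake(
        path, drifted, mode="append", schema_mode="overwrite", engine="rust"
    )
except DeltaError:
    pass

# schema_mode="merge" evolves the table schema, filling existing rows with
# nulls for the new column; it is only supported by the rust engine
# (the pyarrow engine raises ValueError for it).
write_deltalake(path, drifted, mode="append", schema_mode="merge", engine="rust")

# The deprecated flag still works: it emits a DeprecationWarning and is
# internally mapped to schema_mode="overwrite".
write_deltalake(path, drifted, mode="overwrite", overwrite_schema=True, engine="rust")
```

Note that `overwrite_schema=True` remains accepted for backwards compatibility but is translated to `schema_mode="overwrite"`, and that schema merging requires `engine="rust"`, as enforced in `python/deltalake/writer.py` above.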