From 7e74b3861d3cbcdf43b2c1b6a342a2d5cba1a2a0 Mon Sep 17 00:00:00 2001 From: Stephen Carman Date: Fri, 29 Nov 2024 11:24:42 -0500 Subject: [PATCH 1/4] feat: rust 2024 edition, moved unity to its own crate Signed-off-by: Stephen Carman --- Cargo.toml | 2 +- crates/aws/src/lib.rs | 3 +- crates/aws/src/storage.rs | 116 +++---- crates/aws/tests/common.rs | 8 +- crates/catalog-unity/Cargo.toml | 38 +++ .../src}/client/backoff.rs | 0 .../src}/client/mock_server.rs | 0 .../src}/client/mod.rs | 4 +- .../src}/client/pagination.rs | 2 +- .../src}/client/retry.rs | 321 +++++++++--------- .../src}/client/token.rs | 0 .../unity => catalog-unity/src}/credential.rs | 216 ++++++------ .../unity => catalog-unity/src}/datafusion.rs | 14 +- crates/catalog-unity/src/error.rs | 41 +++ .../unity/mod.rs => catalog-unity/src/lib.rs} | 148 ++++---- .../unity => catalog-unity/src}/models.rs | 0 crates/core/Cargo.toml | 10 +- crates/core/src/data_catalog/mod.rs | 16 - crates/core/src/delta_datafusion/mod.rs | 2 +- crates/core/tests/command_optimize.rs | 4 +- crates/core/tests/command_restore.rs | 4 +- crates/gcp/tests/context.rs | 10 +- crates/sql/src/logical_plan.rs | 2 +- python/Cargo.toml | 4 +- 24 files changed, 523 insertions(+), 442 deletions(-) create mode 100644 crates/catalog-unity/Cargo.toml rename crates/{core/src/data_catalog => catalog-unity/src}/client/backoff.rs (100%) rename crates/{core/src/data_catalog => catalog-unity/src}/client/mock_server.rs (100%) rename crates/{core/src/data_catalog => catalog-unity/src}/client/mod.rs (99%) rename crates/{core/src/data_catalog => catalog-unity/src}/client/pagination.rs (97%) rename crates/{core/src/data_catalog => catalog-unity/src}/client/retry.rs (57%) rename crates/{core/src/data_catalog => catalog-unity/src}/client/token.rs (100%) rename crates/{core/src/data_catalog/unity => catalog-unity/src}/credential.rs (78%) rename crates/{core/src/data_catalog/unity => catalog-unity/src}/datafusion.rs (98%) create mode 100644 crates/catalog-unity/src/error.rs rename crates/{core/src/data_catalog/unity/mod.rs => catalog-unity/src/lib.rs} (88%) rename crates/{core/src/data_catalog/unity => catalog-unity/src}/models.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index e1832c2349..daebaaa2eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ resolver = "2" [workspace.package] authors = ["Qingping Hou "] -rust-version = "1.80" +rust-version = "1.85" keywords = ["deltalake", "delta", "datalake"] readme = "README.md" edition = "2021" diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index ee7f222701..147ea1bdc6 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -728,7 +728,6 @@ fn extract_version_from_filename(name: &str) -> Option { mod tests { use super::*; use aws_sdk_sts::config::ProvideCredentials; - use object_store::memory::InMemory; use serial_test::serial; @@ -771,7 +770,7 @@ mod tests { let factory = S3LogStoreFactory::default(); let store = InMemory::new(); let url = Url::parse("s3://test-bucket").unwrap(); - std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); + unsafe { std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); } let logstore = factory .with_options(Arc::new(store), &url, &StorageOptions::from(HashMap::new())) .unwrap(); diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index b2ad64d0c1..80e912a0d3 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -529,30 +529,30 @@ mod tests { fn storage_options_default_test() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - - std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var(constants::AWS_REGION, "us-west-1"); - std::env::set_var(constants::AWS_PROFILE, "default"); - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "default_key_id"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "default_secret_key"); - std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - constants::AWS_IAM_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); - std::env::set_var( - #[allow(deprecated)] - constants::AWS_S3_ASSUME_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var( - #[allow(deprecated)] - constants::AWS_S3_ROLE_SESSION_NAME, - "session_name", - ); - std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - + unsafe { + std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var(constants::AWS_REGION, "us-west-1"); + std::env::set_var(constants::AWS_PROFILE, "default"); + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "default_key_id"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "default_secret_key"); + std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + constants::AWS_IAM_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); + std::env::set_var( + #[allow(deprecated)] + constants::AWS_S3_ASSUME_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var( + #[allow(deprecated)] + constants::AWS_S3_ROLE_SESSION_NAME, + "session_name", + ); + std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + } let options = S3StorageOptions::try_default().unwrap(); assert_eq!( S3StorageOptions { @@ -585,7 +585,7 @@ mod tests { fn storage_options_with_only_region_and_credentials() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - std::env::remove_var(constants::AWS_ENDPOINT_URL); + unsafe { std::env::remove_var(constants::AWS_ENDPOINT_URL); } let options = S3StorageOptions::from_map(&hashmap! { constants::AWS_REGION.to_string() => "eu-west-1".to_string(), constants::AWS_ACCESS_KEY_ID.to_string() => "test".to_string(), @@ -676,26 +676,28 @@ mod tests { fn storage_options_mixed_test() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var( - constants::AWS_ENDPOINT_URL_DYNAMODB, - "http://localhost:dynamodb", - ); - std::env::set_var(constants::AWS_REGION, "us-west-1"); - std::env::set_var(constants::AWS_PROFILE, "default"); - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); - std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - constants::AWS_IAM_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); - std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - - std::env::set_var(constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); - std::env::set_var(constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); - std::env::set_var(constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); + unsafe { + std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var( + constants::AWS_ENDPOINT_URL_DYNAMODB, + "http://localhost:dynamodb", + ); + std::env::set_var(constants::AWS_REGION, "us-west-1"); + std::env::set_var(constants::AWS_PROFILE, "default"); + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); + std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + constants::AWS_IAM_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); + std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + + std::env::set_var(constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); + std::env::set_var(constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); + std::env::set_var(constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); + } let options = S3StorageOptions::from_map(&hashmap! { constants::AWS_ACCESS_KEY_ID.to_string() => "test_id_mixed".to_string(), constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret_mixed".to_string(), @@ -767,12 +769,12 @@ mod tests { ScopedEnv::run(|| { clear_env_of_aws_keys(); let raw_options = hashmap! {}; - - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "env_key"); - std::env::set_var(constants::AWS_ENDPOINT_URL, "env_key"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "env_key"); - std::env::set_var(constants::AWS_REGION, "env_key"); - + unsafe { + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "env_key"); + std::env::set_var(constants::AWS_ENDPOINT_URL, "env_key"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "env_key"); + std::env::set_var(constants::AWS_REGION, "env_key"); + } let combined_options = S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); @@ -795,12 +797,12 @@ mod tests { "AWS_SECRET_ACCESS_KEY".to_string() => "options_key".to_string(), "AWS_REGION".to_string() => "options_key".to_string() }; - - std::env::set_var("aws_access_key_id", "env_key"); - std::env::set_var("aws_endpoint", "env_key"); - std::env::set_var("aws_secret_access_key", "env_key"); - std::env::set_var("aws_region", "env_key"); - + unsafe { + std::env::set_var("aws_access_key_id", "env_key"); + std::env::set_var("aws_endpoint", "env_key"); + std::env::set_var("aws_secret_access_key", "env_key"); + std::env::set_var("aws_region", "env_key"); + } let combined_options = S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); diff --git a/crates/aws/tests/common.rs b/crates/aws/tests/common.rs index dfa2a9cd51..1d64d79b30 100644 --- a/crates/aws/tests/common.rs +++ b/crates/aws/tests/common.rs @@ -3,7 +3,7 @@ use deltalake_aws::constants; use deltalake_aws::register_handlers; use deltalake_aws::storage::*; use deltalake_test::utils::*; -use rand::Rng; +use rand::{random, Rng}; use std::process::{Command, ExitStatus, Stdio}; #[derive(Clone, Debug)] @@ -43,14 +43,14 @@ impl StorageIntegration for S3Integration { fn prepare_env(&self) { set_env_if_not_set( constants::LOCK_TABLE_KEY_NAME, - format!("delta_log_it_{}", rand::thread_rng().gen::()), + format!("delta_log_it_{}", random::()), ); match std::env::var(s3_constants::AWS_ENDPOINT_URL).ok() { Some(endpoint_url) if endpoint_url.to_lowercase() == "none" => { - std::env::remove_var(s3_constants::AWS_ENDPOINT_URL) + unsafe { std::env::remove_var(s3_constants::AWS_ENDPOINT_URL) } } Some(_) => (), - None => std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566"), + None => unsafe { std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566") }, } set_env_if_not_set(s3_constants::AWS_ACCESS_KEY_ID, "deltalake"); set_env_if_not_set(s3_constants::AWS_SECRET_ACCESS_KEY, "weloverust"); diff --git a/crates/catalog-unity/Cargo.toml b/crates/catalog-unity/Cargo.toml new file mode 100644 index 0000000000..051dcb05e1 --- /dev/null +++ b/crates/catalog-unity/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "deltalake-catalog-unity" +version = "0.6.0" +authors.workspace = true +keywords.workspace = true +readme.workspace = true +edition.workspace = true +homepage.workspace = true +description.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +async-trait = { workspace = true } +deltalake-core = { version = "0.22", path = "../core", features = ["unity-experimental"] } +thiserror = { workspace = true } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "http2"] } +rand = "0.8" +futures = "0.3" +chrono = "0.4" +tokio.workspace = true +serde.workspace = true +serde_json.workspace = true +dashmap = "6" +tracing = "0.1" +datafusion = { version = "43", optional = true } +datafusion-common = { version = "43", optional = true } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tempfile = "3" +httpmock = { version = "0.8.0-alpha.1", features = [] } + +[features] +default = [] +datafusion = ["dep:datafusion", "datafusion-common"] + diff --git a/crates/core/src/data_catalog/client/backoff.rs b/crates/catalog-unity/src/client/backoff.rs similarity index 100% rename from crates/core/src/data_catalog/client/backoff.rs rename to crates/catalog-unity/src/client/backoff.rs diff --git a/crates/core/src/data_catalog/client/mock_server.rs b/crates/catalog-unity/src/client/mock_server.rs similarity index 100% rename from crates/core/src/data_catalog/client/mock_server.rs rename to crates/catalog-unity/src/client/mock_server.rs diff --git a/crates/core/src/data_catalog/client/mod.rs b/crates/catalog-unity/src/client/mod.rs similarity index 99% rename from crates/core/src/data_catalog/client/mod.rs rename to crates/catalog-unity/src/client/mod.rs index c6cd838076..5f4d981491 100644 --- a/crates/core/src/data_catalog/client/mod.rs +++ b/crates/catalog-unity/src/client/mod.rs @@ -1,8 +1,8 @@ //! Generic utilities reqwest based Catalog implementations pub mod backoff; -#[cfg(test)] -pub mod mock_server; +// #[cfg(test)] +// pub mod mock_server; #[allow(unused)] pub mod pagination; pub mod retry; diff --git a/crates/core/src/data_catalog/client/pagination.rs b/crates/catalog-unity/src/client/pagination.rs similarity index 97% rename from crates/core/src/data_catalog/client/pagination.rs rename to crates/catalog-unity/src/client/pagination.rs index a5225237b4..630ef2aace 100644 --- a/crates/core/src/data_catalog/client/pagination.rs +++ b/crates/catalog-unity/src/client/pagination.rs @@ -3,7 +3,7 @@ use std::future::Future; use futures::Stream; -use crate::data_catalog::DataCatalogResult; +use deltalake_core::data_catalog::DataCatalogResult; /// Takes a paginated operation `op` that when called with: /// diff --git a/crates/core/src/data_catalog/client/retry.rs b/crates/catalog-unity/src/client/retry.rs similarity index 57% rename from crates/core/src/data_catalog/client/retry.rs rename to crates/catalog-unity/src/client/retry.rs index 300e7afe7b..b770080bcd 100644 --- a/crates/core/src/data_catalog/client/retry.rs +++ b/crates/catalog-unity/src/client/retry.rs @@ -7,7 +7,7 @@ use reqwest::header::LOCATION; use reqwest::{Response, StatusCode}; use std::time::{Duration, Instant}; use tracing::info; - +use deltalake_core::DataCatalogError; use super::backoff::{Backoff, BackoffConfig}; /// Retry request error @@ -61,12 +61,27 @@ impl From for std::io::Error { } } +impl From for reqwest::Error { + fn from(value: RetryError) -> Self { + Into::into(value) + } +} + +impl From for DataCatalogError { + fn from(value: RetryError) -> Self { + DataCatalogError::Generic { + catalog: "", + source: Box::new(value), + } + } +} + /// Error retrying http requests pub type Result = std::result::Result; /// Contains the configuration for how to respond to server errors /// -/// By default they will be retried up to some limit, using exponential +/// By default, they will be retried up to some limit, using exponential /// backoff with jitter. See [`BackoffConfig`] for more information /// #[derive(Debug, Clone)] @@ -177,8 +192,8 @@ impl RetryExt for reqwest::RequestBuilder { { let mut do_retry = false; if let Some(source) = e.source() { - if let Some(e) = source.downcast_ref::() { - if e.is_connect() || e.is_closed() || e.is_incomplete_message() { + if let Some(e) = source.downcast_ref::() { + if e.is_timeout() || e.is_request() || e.is_connect() { do_retry = true; } } @@ -208,158 +223,158 @@ impl RetryExt for reqwest::RequestBuilder { #[cfg(test)] mod tests { - use super::super::mock_server::MockServer; + // use super::super::mock_server::MockServer; use super::RetryConfig; use super::RetryExt; - use hyper::header::LOCATION; - use hyper::{Body, Response}; + // use hyper::header::LOCATION; + // use hyper::{ Response}; use reqwest::{Client, Method, StatusCode}; use std::time::Duration; - #[tokio::test] - async fn test_retry() { - let mock = MockServer::new(); - - let retry = RetryConfig { - backoff: Default::default(), - max_retries: 2, - retry_timeout: Duration::from_secs(1000), - }; - - let client = Client::new(); - let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); - - // Simple request should work - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Returns client errors immediately with status message - mock.push( - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("cupcakes")) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "cupcakes"); - - // Handles client errors with no payload - mock.push( - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::empty()) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "No Body"); - - // Should retry server error request - mock.push( - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Accepts 204 status code - mock.push( - Response::builder() - .status(StatusCode::NO_CONTENT) - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::NO_CONTENT); - - // Follows 402 redirects - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/foo") - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - assert_eq!(r.url().path(), "/foo"); - - // Follows 401 redirects - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/bar") - .body(Body::empty()) - .unwrap(), - ); - - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - assert_eq!(r.url().path(), "/bar"); - - // Handles redirect loop - for _ in 0..10 { - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .header(LOCATION, "/bar") - .body(Body::empty()) - .unwrap(), - ); - } - - let e = do_request().await.unwrap_err().to_string(); - assert!(e.ends_with("too many redirects"), "{}", e); - - // Handles redirect missing location - mock.push( - Response::builder() - .status(StatusCode::FOUND) - .body(Body::empty()) - .unwrap(), - ); - - let e = do_request().await.unwrap_err(); - assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); - - // Gives up after the retrying the specified number of times - for _ in 0..=retry.max_retries { - mock.push( - Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(Body::from("ignored")) - .unwrap(), - ); - } - - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "502 Bad Gateway"); - - // Panic results in an incomplete message error in the client - mock.push_fn(|_| panic!()); - let r = do_request().await.unwrap(); - assert_eq!(r.status(), StatusCode::OK); - - // Gives up after retrying mulitiple panics - for _ in 0..=retry.max_retries { - mock.push_fn(|_| panic!()); - } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "request error"); - - // Shutdown - mock.shutdown().await - } + // #[tokio::test] + // async fn test_retry() { + // let mock = MockServer::new(); + // + // let retry = RetryConfig { + // backoff: Default::default(), + // max_retries: 2, + // retry_timeout: Duration::from_secs(1000), + // }; + // + // let client = Client::new(); + // let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); + // + // // Simple request should work + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::OK); + // + // // Returns client errors immediately with status message + // mock.push( + // Response::builder() + // .status(StatusCode::BAD_REQUEST) + // .body(Body::from("cupcakes")) + // .unwrap(), + // ); + // + // let e = do_request().await.unwrap_err(); + // assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + // assert_eq!(e.retries, 0); + // assert_eq!(&e.message, "cupcakes"); + // + // // Handles client errors with no payload + // mock.push( + // Response::builder() + // .status(StatusCode::BAD_REQUEST) + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let e = do_request().await.unwrap_err(); + // assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + // assert_eq!(e.retries, 0); + // assert_eq!(&e.message, "No Body"); + // + // // Should retry server error request + // mock.push( + // Response::builder() + // .status(StatusCode::BAD_GATEWAY) + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::OK); + // + // // Accepts 204 status code + // mock.push( + // Response::builder() + // .status(StatusCode::NO_CONTENT) + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::NO_CONTENT); + // + // // Follows 402 redirects + // mock.push( + // Response::builder() + // .status(StatusCode::FOUND) + // .header(LOCATION, "/foo") + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::OK); + // assert_eq!(r.url().path(), "/foo"); + // + // // Follows 401 redirects + // mock.push( + // Response::builder() + // .status(StatusCode::FOUND) + // .header(LOCATION, "/bar") + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::OK); + // assert_eq!(r.url().path(), "/bar"); + // + // // Handles redirect loop + // for _ in 0..10 { + // mock.push( + // Response::builder() + // .status(StatusCode::FOUND) + // .header(LOCATION, "/bar") + // .body(Body::empty()) + // .unwrap(), + // ); + // } + // + // let e = do_request().await.unwrap_err().to_string(); + // assert!(e.ends_with("too many redirects"), "{}", e); + // + // // Handles redirect missing location + // mock.push( + // Response::builder() + // .status(StatusCode::FOUND) + // .body(Body::empty()) + // .unwrap(), + // ); + // + // let e = do_request().await.unwrap_err(); + // assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); + // + // // Gives up after the retrying the specified number of times + // for _ in 0..=retry.max_retries { + // mock.push( + // Response::builder() + // .status(StatusCode::BAD_GATEWAY) + // .body(Body::from("ignored")) + // .unwrap(), + // ); + // } + // + // let e = do_request().await.unwrap_err(); + // assert_eq!(e.retries, retry.max_retries); + // assert_eq!(e.message, "502 Bad Gateway"); + // + // // Panic results in an incomplete message error in the client + // mock.push_fn(|_| panic!()); + // let r = do_request().await.unwrap(); + // assert_eq!(r.status(), StatusCode::OK); + // + // // Gives up after retrying mulitiple panics + // for _ in 0..=retry.max_retries { + // mock.push_fn(|_| panic!()); + // } + // let e = do_request().await.unwrap_err(); + // assert_eq!(e.retries, retry.max_retries); + // assert_eq!(e.message, "request error"); + // + // // Shutdown + // mock.shutdown().await + // } } diff --git a/crates/core/src/data_catalog/client/token.rs b/crates/catalog-unity/src/client/token.rs similarity index 100% rename from crates/core/src/data_catalog/client/token.rs rename to crates/catalog-unity/src/client/token.rs diff --git a/crates/core/src/data_catalog/unity/credential.rs b/crates/catalog-unity/src/credential.rs similarity index 78% rename from crates/core/src/data_catalog/unity/credential.rs rename to crates/catalog-unity/src/credential.rs index e0a833f182..8238b9f76c 100644 --- a/crates/core/src/data_catalog/unity/credential.rs +++ b/crates/catalog-unity/src/credential.rs @@ -8,9 +8,9 @@ use reqwest::{Client, Method}; use serde::Deserialize; use super::UnityCatalogError; -use crate::data_catalog::client::retry::{RetryConfig, RetryExt}; -use crate::data_catalog::client::token::{TemporaryToken, TokenCache}; -use crate::data_catalog::DataCatalogResult; +use crate::client::retry::{RetryConfig, RetryExt}; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::DataCatalogResult; // https://learn.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/authentication @@ -415,113 +415,113 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { #[cfg(test)] mod tests { use super::*; - use crate::data_catalog::client::mock_server::MockServer; + // use crate::client::mock_server::MockServer; use futures::executor::block_on; - use hyper::body::to_bytes; - use hyper::{Body, Response}; + + // use hyper::{ Response}; use reqwest::{Client, Method}; use tempfile::NamedTempFile; - #[tokio::test] - async fn test_managed_identity() { - let server = MockServer::new(); - - std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); - - let endpoint = server.url(); - let client = Client::new(); - let retry_config = RetryConfig::default(); - - // Test IMDS - server.push_fn(|req| { - assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token"); - assert!(req.uri().query().unwrap().contains("client_id=client_id")); - assert_eq!(req.method(), &Method::GET); - let t = req - .headers() - .get("x-identity-header") - .unwrap() - .to_str() - .unwrap(); - assert_eq!(t, "env-secret"); - let t = req.headers().get("metadata").unwrap().to_str().unwrap(); - assert_eq!(t, "true"); - Response::new(Body::from( - r#" - { - "access_token": "TOKEN", - "refresh_token": "", - "expires_in": "3599", - "expires_on": "1506484173", - "not_before": "1506480273", - "resource": "https://management.azure.com/", - "token_type": "Bearer" - } - "#, - )) - }); - - let credential = ImdsManagedIdentityOAuthProvider::new( - Some("client_id".into()), - None, - None, - Some(format!("{endpoint}/metadata/identity/oauth2/token")), - client.clone(), - ); - - let token = credential - .fetch_token(&client, &retry_config) - .await - .unwrap(); - - assert_eq!(&token.token, "TOKEN"); - } - - #[tokio::test] - async fn test_workload_identity() { - let server = MockServer::new(); - let tokenfile = NamedTempFile::new().unwrap(); - let tenant = "tenant"; - std::fs::write(tokenfile.path(), "federated-token").unwrap(); - - let endpoint = server.url(); - let client = Client::new(); - let retry_config = RetryConfig::default(); - - // Test IMDS - server.push_fn(move |req| { - assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); - assert_eq!(req.method(), &Method::POST); - let body = block_on(to_bytes(req.into_body())).unwrap(); - let body = String::from_utf8(body.to_vec()).unwrap(); - assert!(body.contains("federated-token")); - Response::new(Body::from( - r#" - { - "access_token": "TOKEN", - "refresh_token": "", - "expires_in": 3599, - "expires_on": "1506484173", - "not_before": "1506480273", - "resource": "https://management.azure.com/", - "token_type": "Bearer" - } - "#, - )) - }); - - let credential = WorkloadIdentityOAuthProvider::new( - "client_id", - tokenfile.path().to_str().unwrap(), - tenant, - Some(endpoint.to_string()), - ); - - let token = credential - .fetch_token(&client, &retry_config) - .await - .unwrap(); - - assert_eq!(&token.token, "TOKEN"); - } + // #[tokio::test] + // async fn test_managed_identity() { + // let server = MockServer::new(); + // + // std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); + // + // let endpoint = server.url(); + // let client = Client::new(); + // let retry_config = RetryConfig::default(); + // + // // Test IMDS + // server.push_fn(|req| { + // assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token"); + // assert!(req.uri().query().unwrap().contains("client_id=client_id")); + // assert_eq!(req.method(), &Method::GET); + // let t = req + // .headers() + // .get("x-identity-header") + // .unwrap() + // .to_str() + // .unwrap(); + // assert_eq!(t, "env-secret"); + // let t = req.headers().get("metadata").unwrap().to_str().unwrap(); + // assert_eq!(t, "true"); + // Response::new(Body::from( + // r#" + // { + // "access_token": "TOKEN", + // "refresh_token": "", + // "expires_in": "3599", + // "expires_on": "1506484173", + // "not_before": "1506480273", + // "resource": "https://management.azure.com/", + // "token_type": "Bearer" + // } + // "#, + // )) + // }); + // + // let credential = ImdsManagedIdentityOAuthProvider::new( + // Some("client_id".into()), + // None, + // None, + // Some(format!("{endpoint}/metadata/identity/oauth2/token")), + // client.clone(), + // ); + // + // let token = credential + // .fetch_token(&client, &retry_config) + // .await + // .unwrap(); + // + // assert_eq!(&token.token, "TOKEN"); + // } + // + // #[tokio::test] + // async fn test_workload_identity() { + // let server = MockServer::new(); + // let tokenfile = NamedTempFile::new().unwrap(); + // let tenant = "tenant"; + // std::fs::write(tokenfile.path(), "federated-token").unwrap(); + // + // let endpoint = server.url(); + // let client = Client::new(); + // let retry_config = RetryConfig::default(); + // + // // Test IMDS + // server.push_fn(move |req| { + // assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); + // assert_eq!(req.method(), &Method::POST); + // let body = block_on(to_bytes(req.into_body())).unwrap(); + // let body = String::from_utf8(body.to_vec()).unwrap(); + // assert!(body.contains("federated-token")); + // Response::new(Body::from( + // r#" + // { + // "access_token": "TOKEN", + // "refresh_token": "", + // "expires_in": 3599, + // "expires_on": "1506484173", + // "not_before": "1506480273", + // "resource": "https://management.azure.com/", + // "token_type": "Bearer" + // } + // "#, + // )) + // }); + // + // let credential = WorkloadIdentityOAuthProvider::new( + // "client_id", + // tokenfile.path().to_str().unwrap(), + // tenant, + // Some(endpoint.to_string()), + // ); + // + // let token = credential + // .fetch_token(&client, &retry_config) + // .await + // .unwrap(); + // + // assert_eq!(&token.token, "TOKEN"); + // } } diff --git a/crates/core/src/data_catalog/unity/datafusion.rs b/crates/catalog-unity/src/datafusion.rs similarity index 98% rename from crates/core/src/data_catalog/unity/datafusion.rs rename to crates/catalog-unity/src/datafusion.rs index 3e32a3ad68..b77409bf04 100644 --- a/crates/core/src/data_catalog/unity/datafusion.rs +++ b/crates/catalog-unity/src/datafusion.rs @@ -11,10 +11,10 @@ use datafusion::datasource::TableProvider; use datafusion_common::DataFusionError; use tracing::error; -use super::models::{GetTableResponse, ListCatalogsResponse, ListTableSummariesResponse}; +use super::models::{GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse}; use super::{DataCatalogResult, UnityCatalog}; -use crate::data_catalog::models::ListSchemasResponse; -use crate::DeltaTableBuilder; + +use deltalake_core::DeltaTableBuilder; /// In-memory list of catalogs populated by unity catalog #[derive(Debug)] @@ -56,10 +56,6 @@ impl CatalogProviderList for UnityCatalogList { self } - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() - } - fn register_catalog( &self, name: String, @@ -68,6 +64,10 @@ impl CatalogProviderList for UnityCatalogList { self.catalogs.insert(name, catalog) } + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + fn catalog(&self, name: &str) -> Option> { self.catalogs.get(name).map(|c| c.value().clone()) } diff --git a/crates/catalog-unity/src/error.rs b/crates/catalog-unity/src/error.rs new file mode 100644 index 0000000000..67610c07fe --- /dev/null +++ b/crates/catalog-unity/src/error.rs @@ -0,0 +1,41 @@ +#[derive(thiserror::Error, Debug)] +pub enum UnityCatalogError { + /// A generic error qualified in the message + #[error("Error in {catalog} catalog: {source}")] + Generic { + /// Name of the catalog + catalog: &'static str, + /// Error message + source: Box, + }, + + /// A generic error qualified in the message + + #[error("{source}")] + Retry { + /// Error message + #[from] + source: crate::client::retry::RetryError, + }, + + #[error("Request error: {source}")] + + /// Error from reqwest library + RequestError { + /// The underlying reqwest_middleware::Error + #[from] + source: reqwest::Error, + }, + + /// Error caused by missing environment variable for Unity Catalog. + #[error("Missing Unity Catalog environment variable: {var_name}")] + MissingEnvVar { + /// Variable name + var_name: String, + }, + + /// Error caused by invalid access token value + + #[error("Invalid Databricks personal access token")] + InvalidAccessToken, +} diff --git a/crates/core/src/data_catalog/unity/mod.rs b/crates/catalog-unity/src/lib.rs similarity index 88% rename from crates/core/src/data_catalog/unity/mod.rs rename to crates/catalog-unity/src/lib.rs index e9de725923..00b02f9f12 100644 --- a/crates/core/src/data_catalog/unity/mod.rs +++ b/crates/catalog-unity/src/lib.rs @@ -5,19 +5,24 @@ use std::str::FromStr; use reqwest::header::{HeaderValue, AUTHORIZATION}; -use self::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider}; -use self::models::{ +use crate::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider}; +use crate::models::{ GetSchemaResponse, GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse, }; -use super::client::retry::RetryExt; -use super::{client::retry::RetryConfig, DataCatalog, DataCatalogError, DataCatalogResult}; -use crate::storage::str_is_truthy; +use deltalake_core::data_catalog::DataCatalogResult; +use deltalake_core::{DataCatalog, DataCatalogError}; + +use crate::client::retry::*; +use deltalake_core::storage::str_is_truthy; + +pub mod client; pub mod credential; #[cfg(feature = "datafusion")] pub mod datafusion; pub mod models; +pub mod error; /// Possible errors from the unity-catalog/tables API call #[derive(thiserror::Error, Debug)] @@ -242,7 +247,7 @@ impl AsRef for UnityCatalogConfigKey { } } -/// Builder for crateing a UnityCatalogClient +/// Builder for creating a UnityCatalogClient #[derive(Default)] pub struct UnityCatalogBuilder { /// Url of a Databricks workspace @@ -282,7 +287,7 @@ pub struct UnityCatalogBuilder { retry_config: RetryConfig, /// Options for the underlying http client - client_options: super::client::ClientOptions, + client_options: client::ClientOptions, } #[allow(deprecated)] @@ -319,7 +324,7 @@ impl UnityCatalogBuilder { } /// Hydrate builder from key value pairs - pub fn try_with_options, impl Into)>>( + pub fn try_with_options, impl Into)>>( mut self, options: I, ) -> DataCatalogResult { @@ -385,7 +390,7 @@ impl UnityCatalogBuilder { } /// Sets the client options, overriding any already set - pub fn with_client_options(mut self, options: super::client::ClientOptions) -> Self { + pub fn with_client_options(mut self, options: client::ClientOptions) -> Self { self.client_options = options; self } @@ -466,7 +471,7 @@ impl UnityCatalog { // we do the conversion to a HeaderValue here, since it is fallible // and we want to use it in an infallible function HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - super::DataCatalogError::Generic { + DataCatalogError::Generic { catalog: "Unity", source: Box::new(err), } @@ -480,7 +485,7 @@ impl UnityCatalog { // we do the conversion to a HeaderValue here, since it is fallible // and we want to use it in an infallible function HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - super::DataCatalogError::Generic { + DataCatalogError::Generic { catalog: "Unity", source: Box::new(err), } @@ -529,6 +534,7 @@ impl UnityCatalog { .get(format!("{}/schemas", self.catalog_url())) .header(AUTHORIZATION, token) .query(&[("catalog_name", catalog_name.as_ref())]) + .send_retry(&self.retry_config) .await?; Ok(resp.json().await?) @@ -645,7 +651,7 @@ impl DataCatalog for UnityCatalog { error_code: err.error_code, message: err.message, } - .into()), + .into()), } } } @@ -658,66 +664,60 @@ impl std::fmt::Debug for UnityCatalog { #[cfg(test)] mod tests { - use crate::data_catalog::client::ClientOptions; - - use super::super::client::mock_server::MockServer; - use super::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE}; - use super::*; - use hyper::{Body, Response}; - use reqwest::Method; - - #[tokio::test] - async fn test_unity_client() { - let server = MockServer::new(); - - let options = ClientOptions::default().with_allow_http(true); - let client = UnityCatalogBuilder::new() - .with_workspace_url(server.url()) - .with_bearer_token("bearer_token") - .with_client_options(options) - .build() - .unwrap(); - - server.push_fn(move |req| { - assert_eq!(req.uri().path(), "/api/2.1/unity-catalog/schemas"); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(LIST_SCHEMAS_RESPONSE)) - }); - - let list_schemas_response = client.list_schemas("catalog_name").await.unwrap(); - assert!(matches!( - list_schemas_response, - ListSchemasResponse::Success { .. } - )); - - server.push_fn(move |req| { - assert_eq!( - req.uri().path(), - "/api/2.1/unity-catalog/schemas/catalog_name.schema_name" - ); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(GET_SCHEMA_RESPONSE)) - }); - - let get_schema_response = client - .get_schema("catalog_name", "schema_name") - .await - .unwrap(); - assert!(matches!(get_schema_response, GetSchemaResponse::Success(_))); - - server.push_fn(move |req| { - assert_eq!( - req.uri().path(), - "/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name" - ); - assert_eq!(req.method(), &Method::GET); - Response::new(Body::from(GET_TABLE_RESPONSE)) - }); - - let get_table_response = client - .get_table("catalog_name", "schema_name", "table_name") - .await - .unwrap(); - assert!(matches!(get_table_response, GetTableResponse::Success(_))); - } + use httpmock::prelude::*; + + // #[tokio::test] + // async fn test_unity_client() { + // let server = MockServer::new(); + // + // let options = ClientOptions::default().with_allow_http(true); + // let client = UnityCatalogBuilder::new() + // .with_workspace_url(server.url()) + // .with_bearer_token("bearer_token") + // .with_client_options(options) + // .build() + // .unwrap(); + // + // server.push_fn(move |req| { + // assert_eq!(req.uri().path(), "/api/2.1/unity-catalog/schemas"); + // assert_eq!(req.method(), Method::GET); + // Response::new(Body::from(LIST_SCHEMAS_RESPONSE)) + // }); + // + // let list_schemas_response = client.list_schemas("catalog_name").await.unwrap(); + // assert!(matches!( + // list_schemas_response, + // ListSchemasResponse::Success { .. } + // )); + // + // server.push_fn(move |req| { + // assert_eq!( + // req.uri().path(), + // "/api/2.1/unity-catalog/schemas/catalog_name.schema_name" + // ); + // assert_eq!(req.method(), &Method::GET); + // Response::new(Body::from(GET_SCHEMA_RESPONSE)) + // }); + // + // let get_schema_response = client + // .get_schema("catalog_name", "schema_name") + // .await + // .unwrap(); + // assert!(matches!(get_schema_response, GetSchemaResponse::Success(_))); + // + // server.push_fn(move |req| { + // assert_eq!( + // req.uri().path(), + // "/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name" + // ); + // assert_eq!(req.method(), &Method::GET); + // Response::new(Body::from(GET_TABLE_RESPONSE)) + // }); + // + // let get_table_response = client + // .get_table("catalog_name", "schema_name", "table_name") + // .await + // .unwrap(); + // assert!(matches!(get_table_response, GetTableResponse::Success(_))); + // } } diff --git a/crates/core/src/data_catalog/unity/models.rs b/crates/catalog-unity/src/models.rs similarity index 100% rename from crates/core/src/data_catalog/unity/models.rs rename to crates/catalog-unity/src/models.rs diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 57a9496070..8174941f72 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -81,7 +81,7 @@ dashmap = "6" errno = "0.3" either = "1.8" fix-hidden-lifetime-bug = "0.2" -hyper = { version = "0.14", optional = true } +#hyper = { version = "0.14", optional = true } indexmap = "2.2.1" itertools = "0.13" lazy_static = "1" @@ -97,10 +97,10 @@ tracing = { workspace = true } rand = "0.8" z85 = "3.0.5" maplit = "1" -sqlparser = { version = "0.51" } +sqlparser = { version = "0.52.0" } # Unity -reqwest = { version = "0.11.18", default-features = false, features = [ +reqwest = { version = "0.12.9", default-features = false, features = [ "rustls-tls", "json", ], optional = true } @@ -111,7 +111,7 @@ ctor = "0" deltalake-test = { path = "../test", features = ["datafusion"] } dotenvy = "0" fs_extra = "1.2.0" -hyper = { version = "0.14", features = ["server"] } +#hyper = { version = "0.14", features = ["server"] } maplit = "1" pretty_assertions = "1.2.1" pretty_env_logger = "0.5.0" @@ -137,4 +137,4 @@ datafusion = [ datafusion-ext = ["datafusion"] json = ["parquet/json"] python = ["arrow/pyarrow"] -unity-experimental = ["reqwest", "hyper"] +unity-experimental = ["reqwest"] diff --git a/crates/core/src/data_catalog/mod.rs b/crates/core/src/data_catalog/mod.rs index eaa02ff09a..5ae5e9aa23 100644 --- a/crates/core/src/data_catalog/mod.rs +++ b/crates/core/src/data_catalog/mod.rs @@ -2,15 +2,8 @@ use std::fmt::Debug; -#[cfg(feature = "unity-experimental")] -pub use unity::*; - -#[cfg(feature = "unity-experimental")] -pub mod client; #[cfg(feature = "datafusion")] pub mod storage; -#[cfg(feature = "unity-experimental")] -pub mod unity; /// A result type for data catalog implementations pub type DataCatalogResult = Result; @@ -27,15 +20,6 @@ pub enum DataCatalogError { source: Box, }, - /// A generic error qualified in the message - #[cfg(feature = "unity-experimental")] - #[error("{source}")] - Retry { - /// Error message - #[from] - source: client::retry::RetryError, - }, - #[error("Request error: {source}")] #[cfg(feature = "unity-experimental")] /// Error from reqwest library diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 034781b85c..1af1f566c5 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -229,7 +229,7 @@ fn _arrow_schema(snapshot: &Snapshot, wrap_partitions: bool) -> DeltaResult( snapshot: &'a EagerSnapshot, filters: &[Expr], -) -> DeltaResult + 'a> { +) -> DeltaResult + 'a + use<'a>> { if let Some(Some(predicate)) = (!filters.is_empty()).then_some(conjunction(filters.iter().cloned())) { diff --git a/crates/core/tests/command_optimize.rs b/crates/core/tests/command_optimize.rs index 13cbd168e4..e96ec08a6e 100644 --- a/crates/core/tests/command_optimize.rs +++ b/crates/core/tests/command_optimize.rs @@ -77,8 +77,8 @@ fn generate_random_batch>( let s = partition.into(); for _ in 0..rows { - x_vec.push(rng.gen()); - y_vec.push(rng.gen()); + x_vec.push(rng.r#gen()); + y_vec.push(rng.r#gen()); date_vec.push(s.clone()); } diff --git a/crates/core/tests/command_restore.rs b/crates/core/tests/command_restore.rs index 5013556ab8..9ac3f331da 100644 --- a/crates/core/tests/command_restore.rs +++ b/crates/core/tests/command_restore.rs @@ -74,8 +74,8 @@ fn get_record_batch() -> RecordBatch { let mut rng = rand::thread_rng(); for _ in 0..10 { - id_vec.push(rng.gen()); - value_vec.push(rng.gen()); + id_vec.push(rng.r#gen()); + value_vec.push(rng.r#gen()); } let schema = ArrowSchema::new(vec![ diff --git a/crates/gcp/tests/context.rs b/crates/gcp/tests/context.rs index 4bcc2c1b3b..5dc0f8cb44 100644 --- a/crates/gcp/tests/context.rs +++ b/crates/gcp/tests/context.rs @@ -76,10 +76,12 @@ impl StorageIntegration for GcpIntegration { let account_path = self.temp_dir.path().join("gcs.json"); info!("account_path: {account_path:?}"); std::fs::write(&account_path, serde_json::to_vec(&token).unwrap()).unwrap(); - std::env::set_var( - "GOOGLE_SERVICE_ACCOUNT", - account_path.as_path().to_str().unwrap(), - ); + unsafe { + std::env::set_var( + "GOOGLE_SERVICE_ACCOUNT", + account_path.as_path().to_str().unwrap(), + ); + } } fn bucket_name(&self) -> String { diff --git a/crates/sql/src/logical_plan.rs b/crates/sql/src/logical_plan.rs index 9f154c0204..27da52a96d 100644 --- a/crates/sql/src/logical_plan.rs +++ b/crates/sql/src/logical_plan.rs @@ -33,7 +33,7 @@ impl DeltaStatement { impl Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { - DeltaStatement::Vacuum(Vacuum { + &DeltaStatement::Vacuum(Vacuum { ref table, ref dry_run, ref retention_hours, diff --git a/python/Cargo.toml b/python/Cargo.toml index bb6fbba621..8f44393819 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -46,8 +46,8 @@ reqwest = { version = "*", features = ["native-tls-vendored"] } deltalake-mount = { path = "../crates/mount" } [dependencies.pyo3] -version = "0.22.2" -features = ["extension-module", "abi3", "abi3-py39"] +version = "0.22.6" +features = ["extension-module", "abi3", "abi3-py39", "gil-refs"] [dependencies.deltalake] path = "../crates/deltalake" From a625768971aec72b20207d39992a33d776decc2d Mon Sep 17 00:00:00 2001 From: Stephen Carman Date: Sun, 15 Dec 2024 12:17:45 -0500 Subject: [PATCH 2/4] feat: move unity catalog integration into its own crate Signed-off-by: Stephen Carman --- Cargo.toml | 2 +- crates/aws/src/lib.rs | 1 + crates/aws/src/storage.rs | 114 ++++--- crates/aws/tests/common.rs | 2 +- crates/catalog-glue/src/lib.rs | 2 + crates/catalog-unity/Cargo.toml | 16 +- .../catalog-unity/src/client/mock_server.rs | 94 ------ crates/catalog-unity/src/client/mod.rs | 31 +- crates/catalog-unity/src/client/retry.rs | 282 +----------------- crates/catalog-unity/src/credential.rs | 253 ++++++++-------- crates/catalog-unity/src/datafusion.rs | 4 +- crates/catalog-unity/src/lib.rs | 204 ++++++------- crates/core/Cargo.toml | 15 +- crates/core/src/data_catalog/mod.rs | 31 +- crates/core/src/data_catalog/storage/mod.rs | 4 +- crates/core/src/delta_datafusion/expr.rs | 8 +- crates/core/src/delta_datafusion/mod.rs | 4 +- crates/core/src/kernel/models/actions.rs | 12 +- crates/core/src/kernel/snapshot/mod.rs | 2 +- crates/core/src/operations/merge/filter.rs | 7 +- crates/core/src/operations/transaction/mod.rs | 2 +- crates/core/src/writer/stats.rs | 1 + crates/deltalake/Cargo.toml | 3 +- 23 files changed, 361 insertions(+), 733 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index daebaaa2eb..e1832c2349 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ resolver = "2" [workspace.package] authors = ["Qingping Hou "] -rust-version = "1.85" +rust-version = "1.80" keywords = ["deltalake", "delta", "datalake"] readme = "README.md" edition = "2021" diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index 147ea1bdc6..981d368aa8 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -728,6 +728,7 @@ fn extract_version_from_filename(name: &str) -> Option { mod tests { use super::*; use aws_sdk_sts::config::ProvideCredentials; + use object_store::memory::InMemory; use serial_test::serial; diff --git a/crates/aws/src/storage.rs b/crates/aws/src/storage.rs index 80e912a0d3..019071a60f 100644 --- a/crates/aws/src/storage.rs +++ b/crates/aws/src/storage.rs @@ -529,30 +529,30 @@ mod tests { fn storage_options_default_test() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - unsafe { - std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var(constants::AWS_REGION, "us-west-1"); - std::env::set_var(constants::AWS_PROFILE, "default"); - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "default_key_id"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "default_secret_key"); - std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - constants::AWS_IAM_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); - std::env::set_var( - #[allow(deprecated)] - constants::AWS_S3_ASSUME_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var( - #[allow(deprecated)] - constants::AWS_S3_ROLE_SESSION_NAME, - "session_name", - ); - std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - } + + std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var(constants::AWS_REGION, "us-west-1"); + std::env::set_var(constants::AWS_PROFILE, "default"); + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "default_key_id"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "default_secret_key"); + std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + constants::AWS_IAM_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); + std::env::set_var( + #[allow(deprecated)] + constants::AWS_S3_ASSUME_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var( + #[allow(deprecated)] + constants::AWS_S3_ROLE_SESSION_NAME, + "session_name", + ); + std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + let options = S3StorageOptions::try_default().unwrap(); assert_eq!( S3StorageOptions { @@ -585,7 +585,8 @@ mod tests { fn storage_options_with_only_region_and_credentials() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - unsafe { std::env::remove_var(constants::AWS_ENDPOINT_URL); } + std::env::remove_var(constants::AWS_ENDPOINT_URL); + let options = S3StorageOptions::from_map(&hashmap! { constants::AWS_REGION.to_string() => "eu-west-1".to_string(), constants::AWS_ACCESS_KEY_ID.to_string() => "test".to_string(), @@ -676,28 +677,26 @@ mod tests { fn storage_options_mixed_test() { ScopedEnv::run(|| { clear_env_of_aws_keys(); - unsafe { - std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); - std::env::set_var( - constants::AWS_ENDPOINT_URL_DYNAMODB, - "http://localhost:dynamodb", - ); - std::env::set_var(constants::AWS_REGION, "us-west-1"); - std::env::set_var(constants::AWS_PROFILE, "default"); - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); - std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); - std::env::set_var( - constants::AWS_IAM_ROLE_ARN, - "arn:aws:iam::123456789012:role/some_role", - ); - std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); - std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); - - std::env::set_var(constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); - std::env::set_var(constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); - std::env::set_var(constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); - } + std::env::set_var(constants::AWS_ENDPOINT_URL, "http://localhost"); + std::env::set_var( + constants::AWS_ENDPOINT_URL_DYNAMODB, + "http://localhost:dynamodb", + ); + std::env::set_var(constants::AWS_REGION, "us-west-1"); + std::env::set_var(constants::AWS_PROFILE, "default"); + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "wrong_key_id"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "wrong_secret_key"); + std::env::set_var(constants::AWS_S3_LOCKING_PROVIDER, "dynamodb"); + std::env::set_var( + constants::AWS_IAM_ROLE_ARN, + "arn:aws:iam::123456789012:role/some_role", + ); + std::env::set_var(constants::AWS_IAM_ROLE_SESSION_NAME, "session_name"); + std::env::set_var(constants::AWS_WEB_IDENTITY_TOKEN_FILE, "token_file"); + + std::env::set_var(constants::AWS_S3_POOL_IDLE_TIMEOUT_SECONDS, "1"); + std::env::set_var(constants::AWS_STS_POOL_IDLE_TIMEOUT_SECONDS, "2"); + std::env::set_var(constants::AWS_S3_GET_INTERNAL_SERVER_ERROR_RETRIES, "3"); let options = S3StorageOptions::from_map(&hashmap! { constants::AWS_ACCESS_KEY_ID.to_string() => "test_id_mixed".to_string(), constants::AWS_SECRET_ACCESS_KEY.to_string() => "test_secret_mixed".to_string(), @@ -769,12 +768,10 @@ mod tests { ScopedEnv::run(|| { clear_env_of_aws_keys(); let raw_options = hashmap! {}; - unsafe { - std::env::set_var(constants::AWS_ACCESS_KEY_ID, "env_key"); - std::env::set_var(constants::AWS_ENDPOINT_URL, "env_key"); - std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "env_key"); - std::env::set_var(constants::AWS_REGION, "env_key"); - } + std::env::set_var(constants::AWS_ACCESS_KEY_ID, "env_key"); + std::env::set_var(constants::AWS_ENDPOINT_URL, "env_key"); + std::env::set_var(constants::AWS_SECRET_ACCESS_KEY, "env_key"); + std::env::set_var(constants::AWS_REGION, "env_key"); let combined_options = S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); @@ -797,12 +794,11 @@ mod tests { "AWS_SECRET_ACCESS_KEY".to_string() => "options_key".to_string(), "AWS_REGION".to_string() => "options_key".to_string() }; - unsafe { - std::env::set_var("aws_access_key_id", "env_key"); - std::env::set_var("aws_endpoint", "env_key"); - std::env::set_var("aws_secret_access_key", "env_key"); - std::env::set_var("aws_region", "env_key"); - } + std::env::set_var("aws_access_key_id", "env_key"); + std::env::set_var("aws_endpoint", "env_key"); + std::env::set_var("aws_secret_access_key", "env_key"); + std::env::set_var("aws_region", "env_key"); + let combined_options = S3ObjectStoreFactory {}.with_env_s3(&StorageOptions(raw_options)); diff --git a/crates/aws/tests/common.rs b/crates/aws/tests/common.rs index 1d64d79b30..8f4adb7523 100644 --- a/crates/aws/tests/common.rs +++ b/crates/aws/tests/common.rs @@ -3,7 +3,7 @@ use deltalake_aws::constants; use deltalake_aws::register_handlers; use deltalake_aws::storage::*; use deltalake_test::utils::*; -use rand::{random, Rng}; +use rand::random; use std::process::{Command, ExitStatus, Stdio}; #[derive(Clone, Debug)] diff --git a/crates/catalog-glue/src/lib.rs b/crates/catalog-glue/src/lib.rs index e9ef449be2..089ce56ce2 100644 --- a/crates/catalog-glue/src/lib.rs +++ b/crates/catalog-glue/src/lib.rs @@ -60,6 +60,8 @@ const PLACEHOLDER_SUFFIX: &str = "-__PLACEHOLDER__"; #[async_trait::async_trait] impl DataCatalog for GlueDataCatalog { + type Error = DataCatalogError; + /// Get the table storage location from the Glue Data Catalog async fn get_table_storage_location( &self, diff --git a/crates/catalog-unity/Cargo.toml b/crates/catalog-unity/Cargo.toml index 051dcb05e1..8a0827386b 100644 --- a/crates/catalog-unity/Cargo.toml +++ b/crates/catalog-unity/Cargo.toml @@ -12,16 +12,18 @@ repository.workspace = true rust-version.workspace = true [dependencies] -async-trait = { workspace = true } -deltalake-core = { version = "0.22", path = "../core", features = ["unity-experimental"] } -thiserror = { workspace = true } +async-trait.workspace = true +tokio.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +deltalake-core = { version = "0.22", path = "../core" } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "http2"] } +reqwest-retry = "0.7" +reqwest-middleware = "0.4.0" rand = "0.8" futures = "0.3" chrono = "0.4" -tokio.workspace = true -serde.workspace = true -serde_json.workspace = true dashmap = "6" tracing = "0.1" datafusion = { version = "43", optional = true } @@ -30,7 +32,7 @@ datafusion-common = { version = "43", optional = true } [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } tempfile = "3" -httpmock = { version = "0.8.0-alpha.1", features = [] } +httpmock = { version = "0.8.0-alpha.1" } [features] default = [] diff --git a/crates/catalog-unity/src/client/mock_server.rs b/crates/catalog-unity/src/client/mock_server.rs index 9bed67e75c..e69de29bb2 100644 --- a/crates/catalog-unity/src/client/mock_server.rs +++ b/crates/catalog-unity/src/client/mock_server.rs @@ -1,94 +0,0 @@ -use std::collections::VecDeque; -use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::Arc; - -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; -use parking_lot::Mutex; -use tokio::sync::oneshot; -use tokio::task::JoinHandle; - -pub type ResponseFn = Box) -> Response + Send>; - -/// A mock server -pub struct MockServer { - responses: Arc>>, - shutdown: oneshot::Sender<()>, - handle: JoinHandle<()>, - url: String, -} - -impl Default for MockServer { - fn default() -> Self { - Self::new() - } -} - -impl MockServer { - pub fn new() -> Self { - let responses: Arc>> = - Arc::new(Mutex::new(VecDeque::with_capacity(10))); - - let r = Arc::clone(&responses); - let make_service = make_service_fn(move |_conn| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let r = Arc::clone(&r); - async move { - Ok::<_, Infallible>(match r.lock().pop_front() { - Some(r) => r(req), - None => Response::new(Body::from("Hello World")), - }) - } - })) - } - }); - - let (shutdown, rx) = oneshot::channel::<()>(); - let server = Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); - - let url = format!("http://{}", server.local_addr()); - - let handle = tokio::spawn(async move { - server - .with_graceful_shutdown(async { - rx.await.ok(); - }) - .await - .unwrap() - }); - - Self { - responses, - shutdown, - handle, - url, - } - } - - /// The url of the mock server - pub fn url(&self) -> &str { - &self.url - } - - /// Add a response - pub fn push(&self, response: Response) { - self.push_fn(|_| response) - } - - /// Add a response function - pub fn push_fn(&self, f: F) - where - F: FnOnce(Request) -> Response + Send + 'static, - { - self.responses.lock().push_back(Box::new(f)) - } - - /// Shutdown the mock server - pub async fn shutdown(self) { - let _ = self.shutdown.send(()); - self.handle.await.unwrap() - } -} diff --git a/crates/catalog-unity/src/client/mod.rs b/crates/catalog-unity/src/client/mod.rs index 5f4d981491..e88d0fa040 100644 --- a/crates/catalog-unity/src/client/mod.rs +++ b/crates/catalog-unity/src/client/mod.rs @@ -1,15 +1,18 @@ //! Generic utilities reqwest based Catalog implementations pub mod backoff; -// #[cfg(test)] -// pub mod mock_server; #[allow(unused)] pub mod pagination; pub mod retry; pub mod token; +use crate::client::retry::RetryConfig; +use crate::UnityCatalogError; +use deltalake_core::data_catalog::DataCatalogResult; use reqwest::header::{HeaderMap, HeaderValue}; -use reqwest::{Client, ClientBuilder, Proxy}; +use reqwest::{ClientBuilder, Proxy}; +use reqwest_middleware::ClientWithMiddleware; +use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use std::time::Duration; fn map_client_error(e: reqwest::Error) -> super::DataCatalogError { @@ -38,6 +41,7 @@ pub struct ClientOptions { http2_keep_alive_while_idle: bool, http1_only: bool, http2_only: bool, + retry_config: Option, } impl ClientOptions { @@ -164,7 +168,12 @@ impl ClientOptions { self } - pub(crate) fn client(&self) -> super::DataCatalogResult { + pub fn with_retry_config(mut self, cfg: RetryConfig) -> Self { + self.retry_config = Some(cfg); + self + } + + pub(crate) fn client(&self) -> DataCatalogResult { let mut builder = ClientBuilder::new(); match &self.user_agent { @@ -221,9 +230,19 @@ impl ClientOptions { builder = builder.danger_accept_invalid_certs(self.allow_insecure) } - builder + let inner_client = builder .https_only(!self.allow_http) .build() - .map_err(map_client_error) + .map_err(UnityCatalogError::from)?; + let retry_policy = self + .retry_config + .as_ref() + .map(|retry| retry.into()) + .unwrap_or(ExponentialBackoff::builder().build_with_max_retries(3)); + + let middleware = RetryTransientMiddleware::new_with_policy(retry_policy); + Ok(reqwest_middleware::ClientBuilder::new(inner_client) + .with(middleware) + .build()) } } diff --git a/crates/catalog-unity/src/client/retry.rs b/crates/catalog-unity/src/client/retry.rs index b770080bcd..9b3828274e 100644 --- a/crates/catalog-unity/src/client/retry.rs +++ b/crates/catalog-unity/src/client/retry.rs @@ -1,14 +1,10 @@ //! A shared HTTP client implementation incorporating retries -use std::error::Error as StdError; -use futures::future::BoxFuture; -use futures::FutureExt; -use reqwest::header::LOCATION; -use reqwest::{Response, StatusCode}; -use std::time::{Duration, Instant}; -use tracing::info; +use super::backoff::BackoffConfig; use deltalake_core::DataCatalogError; -use super::backoff::{Backoff, BackoffConfig}; +use reqwest::StatusCode; +use reqwest_retry::policies::ExponentialBackoff; +use std::time::Duration; /// Retry request error #[derive(Debug)] @@ -61,12 +57,6 @@ impl From for std::io::Error { } } -impl From for reqwest::Error { - fn from(value: RetryError) -> Self { - Into::into(value) - } -} - impl From for DataCatalogError { fn from(value: RetryError) -> Self { DataCatalogError::Generic { @@ -118,263 +108,11 @@ impl Default for RetryConfig { } } -/// Trait to rend requests with retry -pub trait RetryExt { - /// Dispatch a request with the given retry configuration - /// - /// # Panic - /// - /// This will panic if the request body is a stream - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; -} - -impl RetryExt for reqwest::RequestBuilder { - fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { - let mut backoff = Backoff::new(&config.backoff); - let max_retries = config.max_retries; - let retry_timeout = config.retry_timeout; - - async move { - let mut retries = 0; - let now = Instant::now(); - - loop { - let s = self.try_clone().expect("request body must be cloneable"); - match s.send().await { - Ok(r) => match r.error_for_status_ref() { - Ok(_) if r.status().is_success() => return Ok(r), - Ok(r) => { - let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); - let message = match is_bare_redirect { - true => "Received redirect without LOCATION, this normally indicates an incorrectly configured region".to_string(), - // Not actually sure if this is reachable, but here for completeness - false => format!("request unsuccessful: {}", r.status()), - }; - - return Err(RetryError{ - message, - retries, - source: None, - }) - } - Err(e) => { - let status = r.status(); - - if retries == max_retries - || now.elapsed() > retry_timeout - || !status.is_server_error() { - - // Get the response message if returned a client error - let message = match status.is_client_error() { - true => match r.text().await { - Ok(message) if !message.is_empty() => message, - Ok(_) => "No Body".to_string(), - Err(e) => format!("error getting response body: {e}") - } - false => status.to_string(), - }; - - return Err(RetryError{ - message, - retries, - source: Some(e), - }) - - } - - let sleep = backoff.tick(); - retries += 1; - info!("Encountered server error, backing off for {} seconds, retry {} of {}", sleep.as_secs_f32(), retries, max_retries); - tokio::time::sleep(sleep).await; - } - }, - Err(e) => - { - let mut do_retry = false; - if let Some(source) = e.source() { - if let Some(e) = source.downcast_ref::() { - if e.is_timeout() || e.is_request() || e.is_connect() { - do_retry = true; - } - } - } - - if retries == max_retries - || now.elapsed() > retry_timeout - || !do_retry { - - return Err(RetryError{ - retries, - message: "request error".to_string(), - source: Some(e) - }) - } - let sleep = backoff.tick(); - retries += 1; - info!("Encountered request error ({}) backing off for {} seconds, retry {} of {}", e, sleep.as_secs_f32(), retries, max_retries); - tokio::time::sleep(sleep).await; - } - } - } - } - .boxed() +impl From<&RetryConfig> for ExponentialBackoff { + fn from(val: &RetryConfig) -> ExponentialBackoff { + ExponentialBackoff::builder() + .retry_bounds(val.backoff.init_backoff, val.backoff.max_backoff) + .base(val.backoff.base as u32) + .build_with_max_retries(val.max_retries as u32) } } - -#[cfg(test)] -mod tests { - // use super::super::mock_server::MockServer; - use super::RetryConfig; - use super::RetryExt; - // use hyper::header::LOCATION; - // use hyper::{ Response}; - use reqwest::{Client, Method, StatusCode}; - use std::time::Duration; - - // #[tokio::test] - // async fn test_retry() { - // let mock = MockServer::new(); - // - // let retry = RetryConfig { - // backoff: Default::default(), - // max_retries: 2, - // retry_timeout: Duration::from_secs(1000), - // }; - // - // let client = Client::new(); - // let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); - // - // // Simple request should work - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::OK); - // - // // Returns client errors immediately with status message - // mock.push( - // Response::builder() - // .status(StatusCode::BAD_REQUEST) - // .body(Body::from("cupcakes")) - // .unwrap(), - // ); - // - // let e = do_request().await.unwrap_err(); - // assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - // assert_eq!(e.retries, 0); - // assert_eq!(&e.message, "cupcakes"); - // - // // Handles client errors with no payload - // mock.push( - // Response::builder() - // .status(StatusCode::BAD_REQUEST) - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let e = do_request().await.unwrap_err(); - // assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - // assert_eq!(e.retries, 0); - // assert_eq!(&e.message, "No Body"); - // - // // Should retry server error request - // mock.push( - // Response::builder() - // .status(StatusCode::BAD_GATEWAY) - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::OK); - // - // // Accepts 204 status code - // mock.push( - // Response::builder() - // .status(StatusCode::NO_CONTENT) - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::NO_CONTENT); - // - // // Follows 402 redirects - // mock.push( - // Response::builder() - // .status(StatusCode::FOUND) - // .header(LOCATION, "/foo") - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::OK); - // assert_eq!(r.url().path(), "/foo"); - // - // // Follows 401 redirects - // mock.push( - // Response::builder() - // .status(StatusCode::FOUND) - // .header(LOCATION, "/bar") - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::OK); - // assert_eq!(r.url().path(), "/bar"); - // - // // Handles redirect loop - // for _ in 0..10 { - // mock.push( - // Response::builder() - // .status(StatusCode::FOUND) - // .header(LOCATION, "/bar") - // .body(Body::empty()) - // .unwrap(), - // ); - // } - // - // let e = do_request().await.unwrap_err().to_string(); - // assert!(e.ends_with("too many redirects"), "{}", e); - // - // // Handles redirect missing location - // mock.push( - // Response::builder() - // .status(StatusCode::FOUND) - // .body(Body::empty()) - // .unwrap(), - // ); - // - // let e = do_request().await.unwrap_err(); - // assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); - // - // // Gives up after the retrying the specified number of times - // for _ in 0..=retry.max_retries { - // mock.push( - // Response::builder() - // .status(StatusCode::BAD_GATEWAY) - // .body(Body::from("ignored")) - // .unwrap(), - // ); - // } - // - // let e = do_request().await.unwrap_err(); - // assert_eq!(e.retries, retry.max_retries); - // assert_eq!(e.message, "502 Bad Gateway"); - // - // // Panic results in an incomplete message error in the client - // mock.push_fn(|_| panic!()); - // let r = do_request().await.unwrap(); - // assert_eq!(r.status(), StatusCode::OK); - // - // // Gives up after retrying mulitiple panics - // for _ in 0..=retry.max_retries { - // mock.push_fn(|_| panic!()); - // } - // let e = do_request().await.unwrap_err(); - // assert_eq!(e.retries, retry.max_retries); - // assert_eq!(e.message, "request error"); - // - // // Shutdown - // mock.shutdown().await - // } -} diff --git a/crates/catalog-unity/src/credential.rs b/crates/catalog-unity/src/credential.rs index 8238b9f76c..b6b21b47eb 100644 --- a/crates/catalog-unity/src/credential.rs +++ b/crates/catalog-unity/src/credential.rs @@ -4,13 +4,12 @@ use std::process::Command; use std::time::{Duration, Instant}; use reqwest::header::{HeaderValue, ACCEPT}; -use reqwest::{Client, Method}; +use reqwest::Method; +use reqwest_middleware::ClientWithMiddleware; use serde::Deserialize; use super::UnityCatalogError; -use crate::client::retry::{RetryConfig, RetryExt}; use crate::client::token::{TemporaryToken, TokenCache}; -use crate::DataCatalogResult; // https://learn.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/authentication @@ -37,9 +36,8 @@ pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static { /// get the token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult>; + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError>; } /// Provides credentials for use when signing requests @@ -95,9 +93,8 @@ impl TokenCredential for ClientSecretOAuthProvider { /// Fetch a token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let response: TokenResponse = client .request(Method::POST, &self.token_url) .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) @@ -107,10 +104,12 @@ impl TokenCredential for ClientSecretOAuthProvider { ("scope", &format!("{}/.default", DATABRICKS_RESOURCE_SCOPE)), ("grant_type", "client_credentials"), ]) - .send_retry(retry) - .await? + .send() + .await + .map_err(UnityCatalogError::from)? .json() - .await?; + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -167,9 +166,8 @@ impl TokenCredential for AzureCliCredential { /// Fetch a token async fn fetch_token( &self, - _client: &Client, - _retry: &RetryConfig, - ) -> DataCatalogResult> { + _client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { // on window az is a cmd and it should be called like this // see https://doc.rust-lang.org/nightly/std/process/struct.Command.html let program = if cfg!(target_os = "windows") { @@ -281,9 +279,8 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { /// Fetch a token async fn fetch_token( &self, - client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let token_str = std::fs::read_to_string(&self.federated_token_file) .map_err(|_| UnityCatalogError::FederatedTokenFile)?; @@ -301,10 +298,12 @@ impl TokenCredential for WorkloadIdentityOAuthProvider { ("scope", &format!("{}/.default", DATABRICKS_RESOURCE_SCOPE)), ("grant_type", "client_credentials"), ]) - .send_retry(retry) - .await? + .send() + .await + .map_err(UnityCatalogError::from)? .json() - .await?; + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -340,7 +339,7 @@ pub struct ImdsManagedIdentityOAuthProvider { client_id: Option, object_id: Option, msi_res_id: Option, - client: Client, + client: ClientWithMiddleware, } impl ImdsManagedIdentityOAuthProvider { @@ -350,7 +349,7 @@ impl ImdsManagedIdentityOAuthProvider { object_id: Option, msi_res_id: Option, msi_endpoint: Option, - client: Client, + client: ClientWithMiddleware, ) -> Self { let msi_endpoint = msi_endpoint .unwrap_or_else(|| "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()); @@ -370,9 +369,8 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { /// Fetch a token async fn fetch_token( &self, - _client: &Client, - retry: &RetryConfig, - ) -> DataCatalogResult> { + _client: &ClientWithMiddleware, + ) -> Result, UnityCatalogError> { let resource_scope = format!("{}/.default", DATABRICKS_RESOURCE_SCOPE); let mut query_items = vec![ ("api-version", MSI_API_VERSION), @@ -403,7 +401,13 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { builder = builder.header("x-identity-header", val); }; - let response: MsiTokenResponse = builder.send_retry(retry).await?.json().await?; + let response: MsiTokenResponse = builder + .send() + .await + .map_err(UnityCatalogError::from)? + .json() + .await + .map_err(UnityCatalogError::from)?; Ok(TemporaryToken { token: response.access_token, @@ -415,113 +419,94 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider { #[cfg(test)] mod tests { use super::*; - // use crate::client::mock_server::MockServer; - use futures::executor::block_on; - - // use hyper::{ Response}; - use reqwest::{Client, Method}; + use httpmock::prelude::*; + use reqwest::Client; use tempfile::NamedTempFile; - // #[tokio::test] - // async fn test_managed_identity() { - // let server = MockServer::new(); - // - // std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); - // - // let endpoint = server.url(); - // let client = Client::new(); - // let retry_config = RetryConfig::default(); - // - // // Test IMDS - // server.push_fn(|req| { - // assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token"); - // assert!(req.uri().query().unwrap().contains("client_id=client_id")); - // assert_eq!(req.method(), &Method::GET); - // let t = req - // .headers() - // .get("x-identity-header") - // .unwrap() - // .to_str() - // .unwrap(); - // assert_eq!(t, "env-secret"); - // let t = req.headers().get("metadata").unwrap().to_str().unwrap(); - // assert_eq!(t, "true"); - // Response::new(Body::from( - // r#" - // { - // "access_token": "TOKEN", - // "refresh_token": "", - // "expires_in": "3599", - // "expires_on": "1506484173", - // "not_before": "1506480273", - // "resource": "https://management.azure.com/", - // "token_type": "Bearer" - // } - // "#, - // )) - // }); - // - // let credential = ImdsManagedIdentityOAuthProvider::new( - // Some("client_id".into()), - // None, - // None, - // Some(format!("{endpoint}/metadata/identity/oauth2/token")), - // client.clone(), - // ); - // - // let token = credential - // .fetch_token(&client, &retry_config) - // .await - // .unwrap(); - // - // assert_eq!(&token.token, "TOKEN"); - // } - // - // #[tokio::test] - // async fn test_workload_identity() { - // let server = MockServer::new(); - // let tokenfile = NamedTempFile::new().unwrap(); - // let tenant = "tenant"; - // std::fs::write(tokenfile.path(), "federated-token").unwrap(); - // - // let endpoint = server.url(); - // let client = Client::new(); - // let retry_config = RetryConfig::default(); - // - // // Test IMDS - // server.push_fn(move |req| { - // assert_eq!(req.uri().path(), format!("/{tenant}/oauth2/v2.0/token")); - // assert_eq!(req.method(), &Method::POST); - // let body = block_on(to_bytes(req.into_body())).unwrap(); - // let body = String::from_utf8(body.to_vec()).unwrap(); - // assert!(body.contains("federated-token")); - // Response::new(Body::from( - // r#" - // { - // "access_token": "TOKEN", - // "refresh_token": "", - // "expires_in": 3599, - // "expires_on": "1506484173", - // "not_before": "1506480273", - // "resource": "https://management.azure.com/", - // "token_type": "Bearer" - // } - // "#, - // )) - // }); - // - // let credential = WorkloadIdentityOAuthProvider::new( - // "client_id", - // tokenfile.path().to_str().unwrap(), - // tenant, - // Some(endpoint.to_string()), - // ); - // - // let token = credential - // .fetch_token(&client, &retry_config) - // .await - // .unwrap(); - // - // assert_eq!(&token.token, "TOKEN"); - // } + #[tokio::test] + async fn test_managed_identity() { + let server = MockServer::start_async().await; + + std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret"); + + let client = reqwest_middleware::ClientBuilder::new(Client::new()).build(); + + server + .mock_async(|when, then| { + when.path("/metadata/identity/oauth2/token") + .query_param("client_id", "client_id") + .method("GET") + .header("x-identity-header", "env-secret") + .header("metadata", "true"); + then.body( + r#" + { + "access_token": "TOKEN", + "refresh_token": "", + "expires_in": "3599", + "expires_on": "1506484173", + "not_before": "1506480273", + "resource": "https://management.azure.com/", + "token_type": "Bearer" + } + "#, + ); + }) + .await; + + let credential = ImdsManagedIdentityOAuthProvider::new( + Some("client_id".into()), + None, + None, + Some(server.url("/metadata/identity/oauth2/token")), + client.clone(), + ); + + let token = credential.fetch_token(&client).await.unwrap(); + + assert_eq!(&token.token, "TOKEN"); + } + + #[tokio::test] + async fn test_workload_identity() { + let server = MockServer::start_async().await; + let tokenfile = NamedTempFile::new().unwrap(); + let tenant = "tenant"; + std::fs::write(tokenfile.path(), "federated-token").unwrap(); + + let client = reqwest_middleware::ClientBuilder::new(Client::new()).build(); + + server + .mock_async(|when, then| { + when.path_includes(format!("/{tenant}/oauth2/v2.0/token")) + .method("POST") + .body_includes("federated-token"); + + then.body( + r#" + { + "access_token": "TOKEN", + "refresh_token": "", + "expires_in": 3599, + "expires_on": "1506484173", + "not_before": "1506480273", + "resource": "https://management.azure.com/", + "token_type": "Bearer" + } + "#, + ); + }) + .await; + + let credential = WorkloadIdentityOAuthProvider::new( + "client_id", + tokenfile.path().to_str().unwrap(), + tenant, + Some(server.url(format!("/{tenant}/oauth2/v2.0/token"))), + ); + + let token = credential.fetch_token(&client).await.unwrap(); + + assert_eq!(&token.token, "TOKEN"); + } } diff --git a/crates/catalog-unity/src/datafusion.rs b/crates/catalog-unity/src/datafusion.rs index b77409bf04..23339b0b16 100644 --- a/crates/catalog-unity/src/datafusion.rs +++ b/crates/catalog-unity/src/datafusion.rs @@ -11,7 +11,9 @@ use datafusion::datasource::TableProvider; use datafusion_common::DataFusionError; use tracing::error; -use super::models::{GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse}; +use super::models::{ + GetTableResponse, ListCatalogsResponse, ListSchemasResponse, ListTableSummariesResponse, +}; use super::{DataCatalogResult, UnityCatalog}; use deltalake_core::DeltaTableBuilder; diff --git a/crates/catalog-unity/src/lib.rs b/crates/catalog-unity/src/lib.rs index 00b02f9f12..b8ccd23865 100644 --- a/crates/catalog-unity/src/lib.rs +++ b/crates/catalog-unity/src/lib.rs @@ -1,9 +1,7 @@ //! Databricks Unity Catalog. -//! -//! This module is gated behind the "unity-experimental" feature. use std::str::FromStr; -use reqwest::header::{HeaderValue, AUTHORIZATION}; +use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION}; use crate::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider}; use crate::models::{ @@ -26,7 +24,7 @@ pub mod error; /// Possible errors from the unity-catalog/tables API call #[derive(thiserror::Error, Debug)] -enum UnityCatalogError { +pub enum UnityCatalogError { #[error("GET request error: {source}")] /// Error from reqwest library RequestError { @@ -35,6 +33,13 @@ enum UnityCatalogError { source: reqwest::Error, }, + #[error("Error in middleware: {source}")] + RequestMiddlewareError { + /// The underlying reqwest_middleware::Error + #[from] + source: reqwest_middleware::Error, + }, + /// Request returned error response #[error("Invalid table error: {error_code}: {message}")] InvalidTable { @@ -44,9 +49,11 @@ enum UnityCatalogError { message: String, }, - /// Unknown configuration key - #[error("Unknown configuration key: {0}")] - UnknownConfigKey(String), + #[error("Invalid token for auth header: {header_error}")] + InvalidHeader { + #[from] + header_error: InvalidHeaderValue, + }, /// Unknown configuration key #[error("Missing configuration key: {0}")] @@ -69,10 +76,6 @@ enum UnityCatalogError { impl From for DataCatalogError { fn from(value: UnityCatalogError) -> Self { match value { - UnityCatalogError::UnknownConfigKey(key) => DataCatalogError::UnknownConfigKey { - catalog: "Unity", - key, - }, _ => DataCatalogError::Generic { catalog: "Unity", source: Box::new(value), @@ -221,7 +224,10 @@ impl FromStr for UnityCatalogConfigKey { "workspace_url" | "unity_workspace_url" | "databricks_workspace_url" => { Ok(UnityCatalogConfigKey::WorkspaceUrl) } - _ => Err(UnityCatalogError::UnknownConfigKey(s.into()).into()), + _ => Err(DataCatalogError::UnknownConfigKey { + catalog: "unity", + key: s.to_string(), + }), } } } @@ -451,45 +457,33 @@ impl UnityCatalogBuilder { client, workspace_url, credential, - retry_config: self.retry_config, }) } } /// Databricks Unity Catalog pub struct UnityCatalog { - client: reqwest::Client, + client: reqwest_middleware::ClientWithMiddleware, credential: CredentialProvider, workspace_url: String, - retry_config: RetryConfig, } impl UnityCatalog { - async fn get_credential(&self) -> DataCatalogResult { + async fn get_credential(&self) -> Result { match &self.credential { CredentialProvider::BearerToken(token) => { - // we do the conversion to a HeaderValue here, since it is fallible + // we do the conversion to a HeaderValue here, since it is fallible, // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - DataCatalogError::Generic { - catalog: "Unity", - source: Box::new(err), - } - }) + Ok(HeaderValue::from_str(&format!("Bearer {token}"))?) } CredentialProvider::TokenCredential(cache, cred) => { let token = cache - .get_or_insert_with(|| cred.fetch_token(&self.client, &self.retry_config)) + .get_or_insert_with(|| cred.fetch_token(&self.client)) .await?; - // we do the conversion to a HeaderValue here, since it is fallible + // we do the conversion to a HeaderValue here, since it is fallible, // and we want to use it in an infallible function - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| { - DataCatalogError::Generic { - catalog: "Unity", - source: Box::new(err), - } - }) + Ok(HeaderValue::from_str(&format!("Bearer {token}"))?) } } } @@ -509,9 +503,10 @@ impl UnityCatalog { .client .get(format!("{}/catalogs", self.catalog_url())) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) - .await?; - Ok(resp.json().await?) + .send() + .await + .map_err(UnityCatalogError::from)?; + Ok(resp.json().await.map_err(UnityCatalogError::from)?) } /// List all schemas for a catalog in the metastore. @@ -534,10 +529,10 @@ impl UnityCatalog { .get(format!("{}/schemas", self.catalog_url())) .header(AUTHORIZATION, token) .query(&[("catalog_name", catalog_name.as_ref())]) - - .send_retry(&self.retry_config) - .await?; - Ok(resp.json().await?) + .send() + .await + .map_err(UnityCatalogError::from)?; + Ok(resp.json().await.map_err(UnityCatalogError::from)?) } /// Gets the specified schema within the metastore.# @@ -560,9 +555,10 @@ impl UnityCatalog { schema_name.as_ref() )) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) - .await?; - Ok(resp.json().await?) + .send() + .await + .map_err(UnityCatalogError::from)?; + Ok(resp.json().await.map_err(UnityCatalogError::from)?) } /// Gets an array of summaries for tables for a schema and catalog within the metastore. @@ -591,10 +587,11 @@ impl UnityCatalog { ("schema_name_pattern", schema_name_pattern.as_ref()), ]) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) - .await?; + .send() + .await + .map_err(UnityCatalogError::from)?; - Ok(resp.json().await?) + Ok(resp.json().await.map_err(UnityCatalogError::from)?) } /// Gets a table from the metastore for a specific catalog and schema. @@ -609,7 +606,7 @@ impl UnityCatalog { catalog_id: impl AsRef, database_name: impl AsRef, table_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/tables/get let resp = self @@ -622,7 +619,7 @@ impl UnityCatalog { table_name.as_ref() )) .header(AUTHORIZATION, token) - .send_retry(&self.retry_config) + .send() .await?; Ok(resp.json().await?) @@ -631,13 +628,14 @@ impl UnityCatalog { #[async_trait::async_trait] impl DataCatalog for UnityCatalog { + type Error = UnityCatalogError; /// Get the table storage location from the UnityCatalog async fn get_table_storage_location( &self, catalog_id: Option, database_name: &str, table_name: &str, - ) -> Result { + ) -> Result { match self .get_table( catalog_id.unwrap_or("main".into()), @@ -664,60 +662,64 @@ impl std::fmt::Debug for UnityCatalog { #[cfg(test)] mod tests { + use crate::client::ClientOptions; + use crate::models::tests::{GET_SCHEMA_RESPONSE, GET_TABLE_RESPONSE, LIST_SCHEMAS_RESPONSE}; + use crate::models::*; + use crate::UnityCatalogBuilder; use httpmock::prelude::*; - // #[tokio::test] - // async fn test_unity_client() { - // let server = MockServer::new(); - // - // let options = ClientOptions::default().with_allow_http(true); - // let client = UnityCatalogBuilder::new() - // .with_workspace_url(server.url()) - // .with_bearer_token("bearer_token") - // .with_client_options(options) - // .build() - // .unwrap(); - // - // server.push_fn(move |req| { - // assert_eq!(req.uri().path(), "/api/2.1/unity-catalog/schemas"); - // assert_eq!(req.method(), Method::GET); - // Response::new(Body::from(LIST_SCHEMAS_RESPONSE)) - // }); - // - // let list_schemas_response = client.list_schemas("catalog_name").await.unwrap(); - // assert!(matches!( - // list_schemas_response, - // ListSchemasResponse::Success { .. } - // )); - // - // server.push_fn(move |req| { - // assert_eq!( - // req.uri().path(), - // "/api/2.1/unity-catalog/schemas/catalog_name.schema_name" - // ); - // assert_eq!(req.method(), &Method::GET); - // Response::new(Body::from(GET_SCHEMA_RESPONSE)) - // }); - // - // let get_schema_response = client - // .get_schema("catalog_name", "schema_name") - // .await - // .unwrap(); - // assert!(matches!(get_schema_response, GetSchemaResponse::Success(_))); - // - // server.push_fn(move |req| { - // assert_eq!( - // req.uri().path(), - // "/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name" - // ); - // assert_eq!(req.method(), &Method::GET); - // Response::new(Body::from(GET_TABLE_RESPONSE)) - // }); - // - // let get_table_response = client - // .get_table("catalog_name", "schema_name", "table_name") - // .await - // .unwrap(); - // assert!(matches!(get_table_response, GetTableResponse::Success(_))); - // } + #[tokio::test] + async fn test_unity_client() { + let server = MockServer::start_async().await; + + let options = ClientOptions::default().with_allow_http(true); + + let client = UnityCatalogBuilder::new() + .with_workspace_url(server.url("")) + .with_bearer_token("bearer_token") + .with_client_options(options) + .build() + .unwrap(); + + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/schemas").method("GET"); + then.body(LIST_SCHEMAS_RESPONSE); + }) + .await; + + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/schemas/catalog_name.schema_name") + .method("GET"); + then.body(GET_SCHEMA_RESPONSE); + }) + .await; + + server + .mock_async(|when, then| { + when.path("/api/2.1/unity-catalog/tables/catalog_name.schema_name.table_name") + .method("GET"); + then.body(GET_TABLE_RESPONSE); + }) + .await; + + let list_schemas_response = client.list_schemas("catalog_name").await.unwrap(); + assert!(matches!( + list_schemas_response, + ListSchemasResponse::Success { .. } + )); + + let get_schema_response = client + .get_schema("catalog_name", "schema_name") + .await + .unwrap(); + assert!(matches!(get_schema_response, GetSchemaResponse::Success(_))); + + let get_table_response = client + .get_table("catalog_name", "schema_name", "table_name") + .await + .unwrap(); + assert!(matches!(get_table_response, GetTableResponse::Success(_))); + } } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 8174941f72..f499e76d06 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -20,7 +20,7 @@ delta_kernel.workspace = true # arrow arrow = { workspace = true } arrow-arith = { workspace = true } -arrow-array = { workspace = true , features = ["chrono-tz"]} +arrow-array = { workspace = true, features = ["chrono-tz"] } arrow-buffer = { workspace = true } arrow-cast = { workspace = true } arrow-ipc = { workspace = true } @@ -58,7 +58,7 @@ regex = { workspace = true } thiserror = { workspace = true } uuid = { workspace = true, features = ["serde", "v4"] } url = { workspace = true } -urlencoding = { workspace = true} +urlencoding = { workspace = true } # runtime async-trait = { workspace = true } @@ -81,7 +81,6 @@ dashmap = "6" errno = "0.3" either = "1.8" fix-hidden-lifetime-bug = "0.2" -#hyper = { version = "0.14", optional = true } indexmap = "2.2.1" itertools = "0.13" lazy_static = "1" @@ -99,19 +98,12 @@ z85 = "3.0.5" maplit = "1" sqlparser = { version = "0.52.0" } -# Unity -reqwest = { version = "0.12.9", default-features = false, features = [ - "rustls-tls", - "json", -], optional = true } - [dev-dependencies] criterion = "0.5" ctor = "0" deltalake-test = { path = "../test", features = ["datafusion"] } dotenvy = "0" fs_extra = "1.2.0" -#hyper = { version = "0.14", features = ["server"] } maplit = "1" pretty_assertions = "1.2.1" pretty_env_logger = "0.5.0" @@ -136,5 +128,4 @@ datafusion = [ ] datafusion-ext = ["datafusion"] json = ["parquet/json"] -python = ["arrow/pyarrow"] -unity-experimental = ["reqwest"] +python = ["arrow/pyarrow"] \ No newline at end of file diff --git a/crates/core/src/data_catalog/mod.rs b/crates/core/src/data_catalog/mod.rs index 5ae5e9aa23..fbb44d95c1 100644 --- a/crates/core/src/data_catalog/mod.rs +++ b/crates/core/src/data_catalog/mod.rs @@ -20,28 +20,6 @@ pub enum DataCatalogError { source: Box, }, - #[error("Request error: {source}")] - #[cfg(feature = "unity-experimental")] - /// Error from reqwest library - RequestError { - /// The underlying reqwest_middleware::Error - #[from] - source: reqwest::Error, - }, - - /// Error caused by missing environment variable for Unity Catalog. - #[cfg(feature = "unity-experimental")] - #[error("Missing Unity Catalog environment variable: {var_name}")] - MissingEnvVar { - /// Variable name - var_name: String, - }, - - /// Error caused by invalid access token value - #[cfg(feature = "unity-experimental")] - #[error("Invalid Databricks personal access token")] - InvalidAccessToken, - /// Error representing an invalid Data Catalog. #[error("This data catalog doesn't exist: {data_catalog}")] InvalidDataCatalog { @@ -58,16 +36,23 @@ pub enum DataCatalogError { /// configuration key key: String, }, + + #[error("Error in request: {source}")] + RequestError { + source: Box, + }, } /// Abstractions for data catalog for the Delta table. To add support for new cloud, simply implement this trait. #[async_trait::async_trait] pub trait DataCatalog: Send + Sync + Debug { + type Error; + /// Get the table storage location from the Data Catalog async fn get_table_storage_location( &self, catalog_id: Option, database_name: &str, table_name: &str, - ) -> Result; + ) -> Result; } diff --git a/crates/core/src/data_catalog/storage/mod.rs b/crates/core/src/data_catalog/storage/mod.rs index 110e4aa075..236caf79a8 100644 --- a/crates/core/src/data_catalog/storage/mod.rs +++ b/crates/core/src/data_catalog/storage/mod.rs @@ -88,9 +88,9 @@ impl ListingSchemaProvider { } } -// noramalizes a path fragment to be a valida table name in datafusion +// normalizes a path fragment to be a valida table name in datafusion // - removes some reserved characters (-, +, ., " ") -// - lowecase ascii +// - lowercase ascii fn normalize_table_name(path: &Path) -> Result { Ok(path .file_name() diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index c0e79ba490..b633cae141 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -234,6 +234,10 @@ impl ContextProvider for DeltaContextProvider<'_> { self.state.aggregate_functions().get(name).cloned() } + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + fn get_variable_type(&self, _var: &[String]) -> Option { unimplemented!() } @@ -242,10 +246,6 @@ impl ContextProvider for DeltaContextProvider<'_> { self.state.config_options() } - fn get_window_meta(&self, name: &str) -> Option> { - self.state.window_functions().get(name).cloned() - } - fn udf_names(&self) -> Vec { self.state.scalar_functions().keys().cloned().collect() } diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 1af1f566c5..6c96250288 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -1555,7 +1555,7 @@ fn join_batches_with_add_actions( } /// Determine which files contain a record that satisfies the predicate -pub(crate) async fn find_files_scan<'a>( +pub(crate) async fn find_files_scan( snapshot: &DeltaTableState, log_store: LogStoreRef, state: &SessionState, @@ -1668,7 +1668,7 @@ pub(crate) async fn scan_memory_table( } /// Finds files in a snapshot that match the provided predicate. -pub async fn find_files<'a>( +pub async fn find_files( snapshot: &DeltaTableState, log_store: LogStoreRef, state: &SessionState, diff --git a/crates/core/src/kernel/models/actions.rs b/crates/core/src/kernel/models/actions.rs index ef370b4956..119f561b80 100644 --- a/crates/core/src/kernel/models/actions.rs +++ b/crates/core/src/kernel/models/actions.rs @@ -187,9 +187,9 @@ impl Protocol { let mut converted_writer_features = configuration .iter() .filter(|(_, value)| { - value.as_ref().map_or(false, |v| { - v.to_ascii_lowercase().parse::().is_ok_and(|v| v) - }) + value + .as_ref() + .is_some_and(|v| v.to_ascii_lowercase().parse::().is_ok_and(|v| v)) }) .collect::>>() .keys() @@ -216,9 +216,9 @@ impl Protocol { let converted_reader_features = configuration .iter() .filter(|(_, value)| { - value.as_ref().map_or(false, |v| { - v.to_ascii_lowercase().parse::().is_ok_and(|v| v) - }) + value + .as_ref() + .is_some_and(|v| v.to_ascii_lowercase().parse::().is_ok_and(|v| v)) }) .map(|(key, _)| (*key).clone().into()) .filter(|v| !matches!(v, ReaderFeatures::Other(_))) diff --git a/crates/core/src/kernel/snapshot/mod.rs b/crates/core/src/kernel/snapshot/mod.rs index 25e11b88ca..a85087ea9b 100644 --- a/crates/core/src/kernel/snapshot/mod.rs +++ b/crates/core/src/kernel/snapshot/mod.rs @@ -416,7 +416,7 @@ impl EagerSnapshot { } /// Update the snapshot to the given version - pub async fn update<'a>( + pub async fn update( &mut self, log_store: Arc, target_version: Option, diff --git a/crates/core/src/operations/merge/filter.rs b/crates/core/src/operations/merge/filter.rs index 0745c55830..602df519a1 100644 --- a/crates/core/src/operations/merge/filter.rs +++ b/crates/core/src/operations/merge/filter.rs @@ -252,16 +252,13 @@ pub(crate) fn generalize_filter( } } Expr::InList(in_list) => { - let compare_expr = match generalize_filter( + let compare_expr = generalize_filter( *in_list.expr, partition_columns, source_name, target_name, placeholders, - ) { - Some(expr) => expr, - None => return None, // Return early - }; + )?; let mut list_expr = Vec::new(); for item in in_list.list.into_iter() { diff --git a/crates/core/src/operations/transaction/mod.rs b/crates/core/src/operations/transaction/mod.rs index 88b28a8627..6d80d858b0 100644 --- a/crates/core/src/operations/transaction/mod.rs +++ b/crates/core/src/operations/transaction/mod.rs @@ -649,7 +649,7 @@ pub struct PostCommit<'a> { table_data: Option<&'a dyn TableReference>, } -impl<'a> PostCommit<'a> { +impl PostCommit<'_> { /// Runs the post commit activities async fn run_post_commit_hook(&self) -> DeltaResult { if let Some(table) = self.table_data { diff --git a/crates/core/src/writer/stats.rs b/crates/core/src/writer/stats.rs index 4fe448ea76..10260b8364 100644 --- a/crates/core/src/writer/stats.rs +++ b/crates/core/src/writer/stats.rs @@ -352,6 +352,7 @@ impl StatsScalar { // the precision - scale range take the next smaller (by magnitude) value val = f64::from_bits(val.to_bits() - 1); } + Ok(Self::Decimal(val)) } (Statistics::FixedLenByteArray(v), Some(LogicalType::Uuid)) => { diff --git a/crates/deltalake/Cargo.toml b/crates/deltalake/Cargo.toml index 476f0b5d60..c760a55971 100644 --- a/crates/deltalake/Cargo.toml +++ b/crates/deltalake/Cargo.toml @@ -22,6 +22,7 @@ deltalake-azure = { version = "0.5.0", path = "../azure", optional = true } deltalake-gcp = { version = "0.6.0", path = "../gcp", optional = true } deltalake-hdfs = { version = "0.6.0", path = "../hdfs", optional = true } deltalake-catalog-glue = { version = "0.6.0", path = "../catalog-glue", optional = true } +deltalake-catalog-unity = { version = "0.6.0", path = "../catalog-unity", optional = true } [features] # All of these features are just reflected into the core crate until that @@ -37,7 +38,7 @@ json = ["deltalake-core/json"] python = ["deltalake-core/python"] s3-native-tls = ["deltalake-aws/native-tls"] s3 = ["deltalake-aws/rustls"] -unity-experimental = ["deltalake-core/unity-experimental"] +unity-experimental = ["deltalake-catalog-unity"] [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } From 92b45a01a3c8acc2632180eadae58e82eb4c9646 Mon Sep 17 00:00:00 2001 From: Stephen Carman Date: Sun, 15 Dec 2024 12:17:45 -0500 Subject: [PATCH 3/4] feat: move unity catalog integration into its own crate Signed-off-by: Stephen Carman --- crates/core/src/delta_datafusion/mod.rs | 2 +- crates/core/src/operations/load_cdf.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index 6c96250288..e692dd054b 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -229,7 +229,7 @@ fn _arrow_schema(snapshot: &Snapshot, wrap_partitions: bool) -> DeltaResult( snapshot: &'a EagerSnapshot, filters: &[Expr], -) -> DeltaResult + 'a + use<'a>> { +) -> DeltaResult + 'a> { if let Some(Some(predicate)) = (!filters.is_empty()).then_some(conjunction(filters.iter().cloned())) { diff --git a/crates/core/src/operations/load_cdf.rs b/crates/core/src/operations/load_cdf.rs index a63b5182b2..3d5bed2d26 100644 --- a/crates/core/src/operations/load_cdf.rs +++ b/crates/core/src/operations/load_cdf.rs @@ -173,9 +173,7 @@ impl CdfLoadBuilder { return if self.allow_out_of_range { Ok((change_files, add_files, remove_files)) } else { - Err(DeltaTableError::ChangeDataTimestampGreaterThanCommit { - ending_timestamp: ending_timestamp, - }) + Err(DeltaTableError::ChangeDataTimestampGreaterThanCommit { ending_timestamp }) }; } } From a3f502bc0b69ed32f110711467500c27cd5e3fbb Mon Sep 17 00:00:00 2001 From: Stephen Carman Date: Sun, 15 Dec 2024 12:17:45 -0500 Subject: [PATCH 4/4] feat: move unity catalog integration into its own crate Signed-off-by: Stephen Carman --- crates/aws/src/lib.rs | 4 +++- crates/aws/tests/common.rs | 10 ++++++---- crates/catalog-unity/src/lib.rs | 34 +++++++++++++++------------------ 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index 981d368aa8..c062e47334 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -771,7 +771,9 @@ mod tests { let factory = S3LogStoreFactory::default(); let store = InMemory::new(); let url = Url::parse("s3://test-bucket").unwrap(); - unsafe { std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); } + unsafe { + std::env::remove_var(crate::constants::AWS_S3_LOCKING_PROVIDER); + } let logstore = factory .with_options(Arc::new(store), &url, &StorageOptions::from(HashMap::new())) .unwrap(); diff --git a/crates/aws/tests/common.rs b/crates/aws/tests/common.rs index 8f4adb7523..e32522e2d3 100644 --- a/crates/aws/tests/common.rs +++ b/crates/aws/tests/common.rs @@ -46,11 +46,13 @@ impl StorageIntegration for S3Integration { format!("delta_log_it_{}", random::()), ); match std::env::var(s3_constants::AWS_ENDPOINT_URL).ok() { - Some(endpoint_url) if endpoint_url.to_lowercase() == "none" => { - unsafe { std::env::remove_var(s3_constants::AWS_ENDPOINT_URL) } - } + Some(endpoint_url) if endpoint_url.to_lowercase() == "none" => unsafe { + std::env::remove_var(s3_constants::AWS_ENDPOINT_URL) + }, Some(_) => (), - None => unsafe { std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566") }, + None => unsafe { + std::env::set_var(s3_constants::AWS_ENDPOINT_URL, "http://localhost:4566") + }, } set_env_if_not_set(s3_constants::AWS_ACCESS_KEY_ID, "deltalake"); set_env_if_not_set(s3_constants::AWS_SECRET_ACCESS_KEY, "weloverust"); diff --git a/crates/catalog-unity/src/lib.rs b/crates/catalog-unity/src/lib.rs index b8ccd23865..f5b8d1d08a 100644 --- a/crates/catalog-unity/src/lib.rs +++ b/crates/catalog-unity/src/lib.rs @@ -19,8 +19,8 @@ pub mod client; pub mod credential; #[cfg(feature = "datafusion")] pub mod datafusion; -pub mod models; pub mod error; +pub mod models; /// Possible errors from the unity-catalog/tables API call #[derive(thiserror::Error, Debug)] @@ -330,7 +330,7 @@ impl UnityCatalogBuilder { } /// Hydrate builder from key value pairs - pub fn try_with_options, impl Into)>>( + pub fn try_with_options, impl Into)>>( mut self, options: I, ) -> DataCatalogResult { @@ -496,7 +496,7 @@ impl UnityCatalog { /// all catalogs will be retrieved. Otherwise, only catalogs owned by the caller /// (or for which the caller has the USE_CATALOG privilege) will be retrieved. /// There is no guarantee of a specific ordering of the elements in the array. - pub async fn list_catalogs(&self) -> DataCatalogResult { + pub async fn list_catalogs(&self) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/list let resp = self @@ -504,9 +504,8 @@ impl UnityCatalog { .get(format!("{}/catalogs", self.catalog_url())) .header(AUTHORIZATION, token) .send() - .await - .map_err(UnityCatalogError::from)?; - Ok(resp.json().await.map_err(UnityCatalogError::from)?) + .await?; + Ok(resp.json().await?) } /// List all schemas for a catalog in the metastore. @@ -521,7 +520,7 @@ impl UnityCatalog { pub async fn list_schemas( &self, catalog_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/list let resp = self @@ -530,9 +529,8 @@ impl UnityCatalog { .header(AUTHORIZATION, token) .query(&[("catalog_name", catalog_name.as_ref())]) .send() - .await - .map_err(UnityCatalogError::from)?; - Ok(resp.json().await.map_err(UnityCatalogError::from)?) + .await?; + Ok(resp.json().await?) } /// Gets the specified schema within the metastore.# @@ -543,7 +541,7 @@ impl UnityCatalog { &self, catalog_name: impl AsRef, schema_name: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/schemas/get let resp = self @@ -556,9 +554,8 @@ impl UnityCatalog { )) .header(AUTHORIZATION, token) .send() - .await - .map_err(UnityCatalogError::from)?; - Ok(resp.json().await.map_err(UnityCatalogError::from)?) + .await?; + Ok(resp.json().await?) } /// Gets an array of summaries for tables for a schema and catalog within the metastore. @@ -576,7 +573,7 @@ impl UnityCatalog { &self, catalog_name: impl AsRef, schema_name_pattern: impl AsRef, - ) -> DataCatalogResult { + ) -> Result { let token = self.get_credential().await?; // https://docs.databricks.com/api-explorer/workspace/tables/listsummaries let resp = self @@ -588,10 +585,9 @@ impl UnityCatalog { ]) .header(AUTHORIZATION, token) .send() - .await - .map_err(UnityCatalogError::from)?; + .await?; - Ok(resp.json().await.map_err(UnityCatalogError::from)?) + Ok(resp.json().await?) } /// Gets a table from the metastore for a specific catalog and schema. @@ -649,7 +645,7 @@ impl DataCatalog for UnityCatalog { error_code: err.error_code, message: err.message, } - .into()), + .into()), } } }