Skip to content

Commit

Permalink
Added an error type to the DataCatalog trait to not require all the c…
Browse files Browse the repository at this point in the history
…onversion between error types.

Signed-off-by: Stephen Carman <[email protected]>
  • Loading branch information
hntd187 committed Dec 13, 2024
1 parent 8e697d3 commit aba99ed
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 45 deletions.
4 changes: 3 additions & 1 deletion crates/aws/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ impl S3StorageOptions {

fn ensure_env_var(map: &HashMap<String, String>, key: &str) {
if let Some(val) = str_option(map, key) {
std::env::set_var(key, val);
unsafe {
std::env::set_var(key, val);
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/catalog-glue/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const PLACEHOLDER_SUFFIX: &str = "-__PLACEHOLDER__";

#[async_trait::async_trait]
impl DataCatalog for GlueDataCatalog {
type Error = DataCatalogError;

/// Get the table storage location from the Glue Data Catalog
async fn get_table_storage_location(
&self,
Expand Down
10 changes: 5 additions & 5 deletions crates/catalog-unity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ repository.workspace = true
rust-version.workspace = true

[dependencies]
async-trait = { workspace = true }
async-trait.workspace = true
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
deltalake-core = { version = "0.22", path = "../core" }
thiserror = { workspace = true }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json", "http2"] }
reqwest-retry = "0.7"
reqwest-middleware = "0.4.0"
rand = "0.8"
futures = "0.3"
chrono = "0.4"
tokio.workspace = true
serde.workspace = true
serde_json.workspace = true
dashmap = "6"
tracing = "0.1"
datafusion = { version = "43", optional = true }
Expand Down
10 changes: 5 additions & 5 deletions crates/catalog-unity/src/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static {
async fn fetch_token(
&self,
client: &ClientWithMiddleware,
) -> DataCatalogResult<TemporaryToken<String>>;
) -> Result<TemporaryToken<String>, UnityCatalogError>;
}

/// Provides credentials for use when signing requests
Expand Down Expand Up @@ -95,7 +95,7 @@ impl TokenCredential for ClientSecretOAuthProvider {
async fn fetch_token(
&self,
client: &ClientWithMiddleware,
) -> DataCatalogResult<TemporaryToken<String>> {
) -> Result<TemporaryToken<String>, UnityCatalogError> {
let response: TokenResponse = client
.request(Method::POST, &self.token_url)
.header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
Expand Down Expand Up @@ -168,7 +168,7 @@ impl TokenCredential for AzureCliCredential {
async fn fetch_token(
&self,
_client: &ClientWithMiddleware,
) -> DataCatalogResult<TemporaryToken<String>> {
) -> Result<TemporaryToken<String>, UnityCatalogError> {
// on window az is a cmd and it should be called like this
// see https://doc.rust-lang.org/nightly/std/process/struct.Command.html
let program = if cfg!(target_os = "windows") {
Expand Down Expand Up @@ -281,7 +281,7 @@ impl TokenCredential for WorkloadIdentityOAuthProvider {
async fn fetch_token(
&self,
client: &ClientWithMiddleware,
) -> DataCatalogResult<TemporaryToken<String>> {
) -> Result<TemporaryToken<String>, UnityCatalogError> {
let token_str = std::fs::read_to_string(&self.federated_token_file)
.map_err(|_| UnityCatalogError::FederatedTokenFile)?;

Expand Down Expand Up @@ -371,7 +371,7 @@ impl TokenCredential for ImdsManagedIdentityOAuthProvider {
async fn fetch_token(
&self,
_client: &ClientWithMiddleware,
) -> DataCatalogResult<TemporaryToken<String>> {
) -> Result<TemporaryToken<String>, UnityCatalogError> {
let resource_scope = format!("{}/.default", DATABRICKS_RESOURCE_SCOPE);
let mut query_items = vec![
("api-version", MSI_API_VERSION),
Expand Down
49 changes: 18 additions & 31 deletions crates/catalog-unity/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
//! Databricks Unity Catalog.
//!
//! This module is gated behind the "unity-experimental" feature.
use std::str::FromStr;

use reqwest::header::{HeaderValue, AUTHORIZATION};
use reqwest::header::{HeaderValue, InvalidHeaderValue, AUTHORIZATION};

use crate::credential::{AzureCliCredential, ClientSecretOAuthProvider, CredentialProvider};
use crate::models::{
Expand All @@ -25,7 +23,7 @@ pub mod models;

/// Possible errors from the unity-catalog/tables API call
#[derive(thiserror::Error, Debug)]
enum UnityCatalogError {
pub enum UnityCatalogError {
#[error("GET request error: {source}")]
/// Error from reqwest library
RequestError {
Expand All @@ -50,9 +48,11 @@ enum UnityCatalogError {
message: String,
},

/// Unknown configuration key
#[error("Unknown configuration key: {catalog} in catalog: {key}")]
UnknownConfigKey { catalog: &'static str, key: String },
#[error("Invalid token for auth header: {header_error}")]
InvalidHeader {
#[from]
header_error: InvalidHeaderValue,
},

/// Unknown configuration key
#[error("Missing configuration key: {0}")]
Expand All @@ -75,9 +75,6 @@ enum UnityCatalogError {
impl From<UnityCatalogError> for DataCatalogError {
fn from(value: UnityCatalogError) -> Self {
match value {
UnityCatalogError::UnknownConfigKey { catalog, key } => {
DataCatalogError::UnknownConfigKey { catalog, key }
}
_ => DataCatalogError::Generic {
catalog: "Unity",
source: Box::new(value),
Expand Down Expand Up @@ -227,7 +224,7 @@ impl FromStr for UnityCatalogConfigKey {
Ok(UnityCatalogConfigKey::WorkspaceUrl)
}
_ => Err(DataCatalogError::UnknownConfigKey {
catalog: "",
catalog: "unity",
key: s.to_string(),
}),
}
Expand Down Expand Up @@ -471,31 +468,21 @@ pub struct UnityCatalog {
}

impl UnityCatalog {
async fn get_credential(&self) -> DataCatalogResult<HeaderValue> {
async fn get_credential(&self) -> Result<HeaderValue, UnityCatalogError> {
match &self.credential {
CredentialProvider::BearerToken(token) => {
// we do the conversion to a HeaderValue here, since it is fallible
// we do the conversion to a HeaderValue here, since it is fallible,
// and we want to use it in an infallible function
HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| {
DataCatalogError::Generic {
catalog: "Unity",
source: Box::new(err),
}
})
Ok(HeaderValue::from_str(&format!("Bearer {token}"))?)
}
CredentialProvider::TokenCredential(cache, cred) => {
let token = cache
.get_or_insert_with(|| cred.fetch_token(&self.client))
.await?;

// we do the conversion to a HeaderValue here, since it is fallible
// we do the conversion to a HeaderValue here, since it is fallible,
// and we want to use it in an infallible function
HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| {
DataCatalogError::Generic {
catalog: "Unity",
source: Box::new(err),
}
})
Ok(HeaderValue::from_str(&format!("Bearer {token}"))?)
}
}
}
Expand Down Expand Up @@ -618,7 +605,7 @@ impl UnityCatalog {
catalog_id: impl AsRef<str>,
database_name: impl AsRef<str>,
table_name: impl AsRef<str>,
) -> DataCatalogResult<GetTableResponse> {
) -> Result<GetTableResponse, UnityCatalogError> {
let token = self.get_credential().await?;
// https://docs.databricks.com/api-explorer/workspace/tables/get
let resp = self
Expand All @@ -632,22 +619,22 @@ impl UnityCatalog {
))
.header(AUTHORIZATION, token)
.send()
.await
.map_err(UnityCatalogError::from)?;
.await?;

Ok(resp.json().await.map_err(UnityCatalogError::from)?)
Ok(resp.json().await?)
}
}

#[async_trait::async_trait]
impl DataCatalog for UnityCatalog {
type Error = UnityCatalogError;
/// Get the table storage location from the UnityCatalog
async fn get_table_storage_location(
&self,
catalog_id: Option<String>,
database_name: &str,
table_name: &str,
) -> Result<String, DataCatalogError> {
) -> Result<String, UnityCatalogError> {
match self
.get_table(
catalog_id.unwrap_or("main".into()),
Expand Down
4 changes: 3 additions & 1 deletion crates/core/src/data_catalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ pub enum DataCatalogError {
/// Abstractions for data catalog for the Delta table. To add support for new cloud, simply implement this trait.
#[async_trait::async_trait]
pub trait DataCatalog: Send + Sync + Debug {
type Error;

/// Get the table storage location from the Data Catalog
async fn get_table_storage_location(
&self,
catalog_id: Option<String>,
database_name: &str,
table_name: &str,
) -> Result<String, DataCatalogError>;
) -> Result<String, Self::Error>;
}
4 changes: 2 additions & 2 deletions crates/core/src/data_catalog/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ impl ListingSchemaProvider {
}
}

// noramalizes a path fragment to be a valida table name in datafusion
// normalizes a path fragment to be a valida table name in datafusion
// - removes some reserved characters (-, +, ., " ")
// - lowecase ascii
// - lowercase ascii
fn normalize_table_name(path: &Path) -> Result<String, DataFusionError> {
Ok(path
.file_name()
Expand Down

0 comments on commit aba99ed

Please sign in to comment.