diff --git a/Cargo.lock b/Cargo.lock
index 2a9759c0d46f..0af6c89d92c4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3051,6 +3051,7 @@ dependencies = [
  "polars-schema",
  "polars-time",
  "polars-utils",
+ "pyo3",
  "rayon",
  "regex",
  "reqwest",
@@ -3400,6 +3401,7 @@ dependencies = [
  "num-traits",
  "once_cell",
  "polars-error",
+ "pyo3",
  "rand",
  "raw-cpuid",
  "rayon",
diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml
index c3ed9b93bd5c..2e0a51acc9e4 100644
--- a/crates/polars-io/Cargo.toml
+++ b/crates/polars-io/Cargo.toml
@@ -37,6 +37,7 @@ num-traits = { workspace = true }
 object_store = { workspace = true, optional = true }
 once_cell = { workspace = true }
 percent-encoding = { workspace = true }
+pyo3 = { workspace = true, optional = true }
 rayon = { workspace = true }
 regex = { workspace = true }
 reqwest = { workspace = true, optional = true }
@@ -129,7 +130,7 @@ gcp = ["object_store/gcp", "cloud"]
 http = ["object_store/http", "cloud"]
 temporal = ["dtype-datetime", "dtype-date", "dtype-time"]
 simd = []
-python = ["polars-error/python"]
+python = ["pyo3", "polars-error/python", "polars-utils/python"]
 
 [package.metadata.docs.rs]
 all-features = true
diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs
new file mode 100644
index 000000000000..8926c4d0a835
--- /dev/null
+++ b/crates/polars-io/src/cloud/credential_provider.rs
@@ -0,0 +1,679 @@
+use std::fmt::Debug;
+use std::future::Future;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use async_trait::async_trait;
+#[cfg(feature = "aws")]
+pub use object_store::aws::AwsCredential;
+#[cfg(feature = "azure")]
+pub use object_store::azure::AzureCredential;
+#[cfg(feature = "gcp")]
+pub use object_store::gcp::GcpCredential;
+use polars_core::config;
+use polars_error::{polars_bail, PolarsResult};
+#[cfg(feature = "python")]
+use polars_utils::python_function::PythonFunction;
+#[cfg(feature = "python")]
+use python_impl::PythonCredentialProvider;
+
+#[derive(Clone, Debug, PartialEq, Hash, Eq)]
+pub enum PlCredentialProvider {
+    /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly
+    Function(CredentialProviderFunction),
+    #[cfg(feature = "python")]
+    Python(python_impl::PythonCredentialProvider),
+}
+
+impl PlCredentialProvider {
+    /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH)
+    ///
+    /// This functionality is unstable.
+    pub fn from_func(
+        // Internal notes
+        // * This function is exposed as the Rust API for `PlCredentialProvider`
+        func: impl Fn() -> Pin<
+                Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
+            > + Send
+            + Sync
+            + 'static,
+    ) -> Self {
+        Self::Function(CredentialProviderFunction(Arc::new(func)))
+    }
+
+    #[cfg(feature = "python")]
+    pub fn from_python_func(func: PythonFunction) -> Self {
+        Self::Python(python_impl::PythonCredentialProvider(Arc::new(func)))
+    }
+
+    #[cfg(feature = "python")]
+    pub fn from_python_func_object(func: pyo3::PyObject) -> Self {
+        Self::Python(python_impl::PythonCredentialProvider(Arc::new(
+            PythonFunction(func),
+        )))
+    }
+}
+
+pub enum ObjectStoreCredential {
+    #[cfg(feature = "aws")]
+    Aws(Arc<AwsCredential>),
+    #[cfg(feature = "azure")]
+    Azure(Arc<AzureCredential>),
+    #[cfg(feature = "gcp")]
+    Gcp(Arc<GcpCredential>),
+    /// For testing purposes
+    None,
+}
+
+impl ObjectStoreCredential {
+    fn variant_name(&self) -> &'static str {
+        match self {
+            #[cfg(feature = "aws")]
+            Self::Aws(_) => "Aws",
+            #[cfg(feature = "azure")]
+            Self::Azure(_) => "Azure",
+            #[cfg(feature = "gcp")]
+            Self::Gcp(_) => "Gcp",
+            Self::None => "None",
+        }
+    }
+
+    fn panic_type_mismatch(&self, expected: &str) {
+        panic!(
+            "impl error: credential type mismatch: expected {}, got {} instead",
+            expected,
+            self.variant_name()
+        )
+    }
+
+    #[cfg(feature = "aws")]
+    fn unwrap_aws(self) -> Arc<AwsCredential> {
+        let Self::Aws(v) = self else {
+            self.panic_type_mismatch("aws");
+            unreachable!()
+        };
+        v
+    }
+
+    #[cfg(feature = "azure")]
+    fn unwrap_azure(self) -> Arc<AzureCredential> {
+        let Self::Azure(v) = self else {
+            self.panic_type_mismatch("azure");
+            unreachable!()
+        };
+        v
+    }
+
+    #[cfg(feature = "gcp")]
+    fn unwrap_gcp(self) -> Arc<GcpCredential> {
+        let Self::Gcp(v) = self else {
+            self.panic_type_mismatch("gcp");
+            unreachable!()
+        };
+        v
+    }
+}
+
+pub trait IntoCredentialProvider: Sized {
+    #[cfg(feature = "aws")]
+    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
+        unimplemented!()
+    }
+
+    #[cfg(feature = "azure")]
+    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
+        unimplemented!()
+    }
+
+    #[cfg(feature = "gcp")]
+    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
+        unimplemented!()
+    }
+}
+
+impl IntoCredentialProvider for PlCredentialProvider {
+    #[cfg(feature = "aws")]
+    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
+        match self {
+            Self::Function(v) => v.into_aws_provider(),
+            #[cfg(feature = "python")]
+            Self::Python(v) => v.into_aws_provider(),
+        }
+    }
+
+    #[cfg(feature = "azure")]
+    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
+        match self {
+            Self::Function(v) => v.into_azure_provider(),
+            #[cfg(feature = "python")]
+            Self::Python(v) => v.into_azure_provider(),
+        }
+    }
+
+    #[cfg(feature = "gcp")]
+    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
+        match self {
+            Self::Function(v) => v.into_gcp_provider(),
+            #[cfg(feature = "python")]
+            Self::Python(v) => v.into_gcp_provider(),
+        }
+    }
+}
+
+type CredentialProviderFunctionImpl = Arc<
+    dyn Fn() -> Pin<
+            Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
+        > + Send
+        + Sync,
+>;
+
+/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc.
+#[derive(Clone)]
+pub struct CredentialProviderFunction(CredentialProviderFunctionImpl);
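
A minimal sketch, not part of this diff, of how a Rust caller might supply static credentials through `PlCredentialProvider::from_func`. It mirrors the `core::future::ready` pattern used in `test_serde` further down; the environment variable names are made up:

    use std::sync::Arc;

    use polars_error::PolarsResult;
    use polars_io::cloud::credential_provider::{
        AwsCredential, ObjectStoreCredential, PlCredentialProvider,
    };

    fn static_aws_provider() -> PlCredentialProvider {
        PlCredentialProvider::from_func(|| {
            Box::pin(core::future::ready(PolarsResult::Ok((
                ObjectStoreCredential::Aws(Arc::new(AwsCredential {
                    key_id: std::env::var("EXAMPLE_KEY_ID").unwrap(),
                    secret_key: std::env::var("EXAMPLE_SECRET_KEY").unwrap(),
                    token: None,
                })),
                // A `u64::MAX` expiry means the fetched-credentials cache
                // below never calls this function a second time.
                u64::MAX,
            ))))
        })
    }
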
+macro_rules! build_to_object_store_err {
+    ($s:expr) => {{
+        fn to_object_store_err(
+            e: impl std::error::Error + Send + Sync + 'static,
+        ) -> object_store::Error {
+            object_store::Error::Generic {
+                store: $s,
+                source: Box::new(e),
+            }
+        }
+
+        to_object_store_err
+    }};
+}
+
+impl IntoCredentialProvider for CredentialProviderFunction {
+    #[cfg(feature = "aws")]
+    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
+        #[derive(Debug)]
+        struct S(
+            CredentialProviderFunction,
+            FetchedCredentialsCache<Arc<AwsCredential>>,
+        );
+
+        #[async_trait]
+        impl object_store::CredentialProvider for S {
+            type Credential = object_store::aws::AwsCredential;
+
+            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
+                self.1
+                    .get_maybe_update(async {
+                        let (creds, expiry) = self.0 .0().await?;
+                        PolarsResult::Ok((creds.unwrap_aws(), expiry))
+                    })
+                    .await
+                    .map_err(build_to_object_store_err!("credential-provider-aws"))
+            }
+        }
+
+        Arc::new(S(
+            self,
+            FetchedCredentialsCache::new(Arc::new(AwsCredential {
+                key_id: String::new(),
+                secret_key: String::new(),
+                token: None,
+            })),
+        ))
+    }
+
+    #[cfg(feature = "azure")]
+    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
+        #[derive(Debug)]
+        struct S(
+            CredentialProviderFunction,
+            FetchedCredentialsCache<Arc<AzureCredential>>,
+        );
+
+        #[async_trait]
+        impl object_store::CredentialProvider for S {
+            type Credential = object_store::azure::AzureCredential;
+
+            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
+                self.1
+                    .get_maybe_update(async {
+                        let (creds, expiry) = self.0 .0().await?;
+                        PolarsResult::Ok((creds.unwrap_azure(), expiry))
+                    })
+                    .await
+                    .map_err(build_to_object_store_err!("credential-provider-azure"))
+            }
+        }
+
+        Arc::new(S(
+            self,
+            FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))),
+        ))
+    }
+
+    #[cfg(feature = "gcp")]
+    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
+        #[derive(Debug)]
+        struct S(
+            CredentialProviderFunction,
+            FetchedCredentialsCache<Arc<GcpCredential>>,
+        );
+
+        #[async_trait]
+        impl object_store::CredentialProvider for S {
+            type Credential = object_store::gcp::GcpCredential;
+
+            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
+                self.1
+                    .get_maybe_update(async {
+                        let (creds, expiry) = self.0 .0().await?;
+                        PolarsResult::Ok((creds.unwrap_gcp(), expiry))
+                    })
+                    .await
+                    .map_err(build_to_object_store_err!("credential-provider-gcp"))
+            }
+        }
+
+        Arc::new(S(
+            self,
+            FetchedCredentialsCache::new(Arc::new(GcpCredential {
+                bearer: String::new(),
+            })),
+        ))
+    }
+}
+
+impl Debug for CredentialProviderFunction {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "credential provider function at 0x{:016x}",
+            self.0.as_ref() as *const _ as *const () as usize
+        )
+    }
+}
+
+impl Eq for CredentialProviderFunction {}
+
+impl PartialEq for CredentialProviderFunction {
+    fn eq(&self, other: &Self) -> bool {
+        Arc::ptr_eq(&self.0, &other.0)
+    }
+}
+
+impl Hash for CredentialProviderFunction {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
+    }
+}
+
+#[cfg(feature = "serde")]
+impl<'de> serde::Deserialize<'de> for PlCredentialProvider {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        #[cfg(feature = "python")]
+        {
+            Ok(Self::Python(PythonCredentialProvider::deserialize(
+                deserializer,
+            )?))
+        }
+        #[cfg(not(feature = "python"))]
+        {
+            use serde::de::Error;
+            Err(D::Error::custom("cannot deserialize PlCredentialProvider"))
+        }
+    }
+}
+#[cfg(feature = "serde")] +impl serde::Serialize for PlCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + + // TODO: + // * Add magic bytes here to indicate a python function + // * Check the Python version on deserialize + #[cfg(feature = "python")] + if let PlCredentialProvider::Python(v) = self { + return v.serialize(serializer); + } + + Err(S::Error::custom(format!("cannot serialize {:?}", self))) + } +} + +/// Avoids calling the credential provider function if we have not yet passed the expiry time. +#[derive(Debug)] +struct FetchedCredentialsCache(tokio::sync::Mutex<(C, u64)>); + +impl FetchedCredentialsCache { + fn new(init_creds: C) -> Self { + Self(tokio::sync::Mutex::new((init_creds, 0))) + } + + async fn get_maybe_update( + &self, + // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from + // a `Fn() -> Pin>` by having it wrapped in an `async { f() }` block. We + // will not poll that block if the credentials have not yet expired. + update_func: impl Future>, + ) -> PolarsResult { + let verbose = config::verbose(); + let mut inner = self.0.lock().await; + let (last_fetched_credentials, last_fetched_expiry) = &mut *inner; + + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + // Ensure the credential is valid for at least this many seconds to + // accommodate for latency. + const REQUEST_TIME_BUFFER: u64 = 7; + + if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER { + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Call update_func: current_time = {},\ + last_fetched_expiry = {}", + current_time, *last_fetched_expiry + ) + } + let (credentials, expiry) = update_func.await?; + + *last_fetched_credentials = credentials; + *last_fetched_expiry = expiry; + + if expiry < current_time && expiry != 0 { + polars_bail!( + ComputeError: + "credential expiry time {} is older than system time {} \ + by {} seconds", + expiry, + current_time, + current_time - expiry + ) + } + + if verbose { + eprintln!( + "[FetchedCredentialsCache]: Finish update_func: \ + new expiry = {} (in {} seconds)", + *last_fetched_expiry, + last_fetched_expiry.saturating_sub( + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + ), + ) + } + } + + Ok(last_fetched_credentials.clone()) + } +} + +#[cfg(feature = "python")] +mod python_impl { + use std::hash::Hash; + use std::sync::Arc; + + use polars_error::PolarsError; + use polars_utils::python_function::PythonFunction; + use pyo3::exceptions::PyValueError; + use pyo3::pybacked::PyBackedStr; + use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; + use pyo3::Python; + + use super::IntoCredentialProvider; + + #[derive(Clone, Debug)] + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + pub struct PythonCredentialProvider(pub(super) Arc); + + impl IntoCredentialProvider for PythonCredentialProvider { + #[cfg(feature = "aws")] + fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { + use polars_error::{to_compute_err, PolarsResult}; + + use crate::cloud::credential_provider::{ + CredentialProviderFunction, ObjectStoreCredential, + }; + + CredentialProviderFunction(Arc::new(move || { + let func = self.0.clone(); + Box::pin(async move { + let mut credentials = object_store::aws::AwsCredential { + key_id: String::new(), + secret_key: String::new(), + token: None, + }; + + let expiry = Python::with_gil(|py| { + let v 
+#[cfg(feature = "python")]
+mod python_impl {
+    use std::hash::Hash;
+    use std::sync::Arc;
+
+    use polars_error::PolarsError;
+    use polars_utils::python_function::PythonFunction;
+    use pyo3::exceptions::PyValueError;
+    use pyo3::pybacked::PyBackedStr;
+    use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
+    use pyo3::Python;
+
+    use super::IntoCredentialProvider;
+
+    #[derive(Clone, Debug)]
+    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+    pub struct PythonCredentialProvider(pub(super) Arc<PythonFunction>);
+
+    impl IntoCredentialProvider for PythonCredentialProvider {
+        #[cfg(feature = "aws")]
+        fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
+            use polars_error::{to_compute_err, PolarsResult};
+
+            use crate::cloud::credential_provider::{
+                CredentialProviderFunction, ObjectStoreCredential,
+            };
+
+            CredentialProviderFunction(Arc::new(move || {
+                let func = self.0.clone();
+                Box::pin(async move {
+                    let mut credentials = object_store::aws::AwsCredential {
+                        key_id: String::new(),
+                        secret_key: String::new(),
+                        token: None,
+                    };
+
+                    let expiry = Python::with_gil(|py| {
+                        let v = func.0.call0(py)?.into_bound(py);
+                        let (storage_options, expiry) =
+                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;
+
+                        for (k, v) in storage_options.iter() {
+                            let k = k.extract::<PyBackedStr>()?;
+                            let v = v.extract::<Option<String>>()?;
+
+                            match k.as_ref() {
+                                "aws_access_key_id" => {
+                                    credentials.key_id = v.ok_or_else(|| {
+                                        PyValueError::new_err("aws_access_key_id was None")
+                                    })?;
+                                },
+                                "aws_secret_access_key" => {
+                                    credentials.secret_key = v.ok_or_else(|| {
+                                        PyValueError::new_err("aws_secret_access_key was None")
+                                    })?
+                                },
+                                "aws_session_token" => credentials.token = v,
+                                v => {
+                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
+                                        "unknown configuration key for aws: {}, \
+                                        valid configuration keys are: \
+                                        {}, {}, {}",
+                                        v,
+                                        "aws_access_key_id",
+                                        "aws_secret_access_key",
+                                        "aws_session_token"
+                                    )))
+                                },
+                            }
+                        }
+
+                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
+                    })
+                    .map_err(to_compute_err)?;
+
+                    if credentials.key_id.is_empty() {
+                        return Err(PolarsError::ComputeError(
+                            "aws_access_key_id was empty or not given".into(),
+                        ));
+                    }
+
+                    if credentials.secret_key.is_empty() {
+                        return Err(PolarsError::ComputeError(
+                            "aws_secret_access_key was empty or not given".into(),
+                        ));
+                    }
+
+                    PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry))
+                })
+            }))
+            .into_aws_provider()
+        }
+
+        #[cfg(feature = "azure")]
+        fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
+            use polars_error::{to_compute_err, PolarsResult};
+
+            use crate::cloud::credential_provider::{
+                CredentialProviderFunction, ObjectStoreCredential,
+            };
+
+            CredentialProviderFunction(Arc::new(move || {
+                let func = self.0.clone();
+                Box::pin(async move {
+                    let mut credentials =
+                        object_store::azure::AzureCredential::BearerToken(String::new());
+
+                    let expiry = Python::with_gil(|py| {
+                        let v = func.0.call0(py)?.into_bound(py);
+                        let (storage_options, expiry) =
+                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;
+
+                        for (k, v) in storage_options.iter() {
+                            let k = k.extract::<PyBackedStr>()?;
+                            let v = v.extract::<String>()?;
+
+                            // We only support bearer for now
+                            match k.as_ref() {
+                                "bearer_token" => {
+                                    credentials =
+                                        object_store::azure::AzureCredential::BearerToken(v)
+                                },
+                                v => {
+                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
+                                        "unknown configuration key for azure: {}, \
+                                        valid configuration keys are: {}",
+                                        v, "bearer_token",
+                                    )))
+                                },
+                            }
+                        }
+
+                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
+                    })
+                    .map_err(to_compute_err)?;
+
+                    let object_store::azure::AzureCredential::BearerToken(bearer) = &credentials
+                    else {
+                        unreachable!()
+                    };
+
+                    if bearer.is_empty() {
+                        return Err(PolarsError::ComputeError(
+                            "bearer was empty or not given".into(),
+                        ));
+                    }
+
+                    PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry))
+                })
+            }))
+            .into_azure_provider()
+        }
+
+        #[cfg(feature = "gcp")]
+        fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
+            use polars_error::{to_compute_err, PolarsResult};
+
+            use crate::cloud::credential_provider::{
+                CredentialProviderFunction, ObjectStoreCredential,
+            };
+
+            CredentialProviderFunction(Arc::new(move || {
+                let func = self.0.clone();
+                Box::pin(async move {
+                    let mut credentials = object_store::gcp::GcpCredential {
+                        bearer: String::new(),
+                    };
+
+                    let expiry = Python::with_gil(|py| {
+                        let v = func.0.call0(py)?.into_bound(py);
+                        let (storage_options, expiry) =
+                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;
+
+                        for (k, v) in storage_options.iter() {
+                            let k = k.extract::<PyBackedStr>()?;
+                            let v = v.extract::<String>()?;
+
+                            match k.as_ref() {
+                                "bearer_token" => credentials.bearer = v,
+                                v => {
+                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
+                                        "unknown configuration key for gcp: {}, \
+                                        valid configuration keys are: {}",
+                                        v, "bearer_token",
+                                    )))
+                                },
+                            }
+                        }
+
+                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
+                    })
+                    .map_err(to_compute_err)?;
+
+                    if credentials.bearer.is_empty() {
+                        return Err(PolarsError::ComputeError(
+                            "bearer was empty or not given".into(),
+                        ));
+                    }
+
+                    PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry))
+                })
+            }))
+            .into_gcp_provider()
+        }
+    }
+
+    impl Eq for PythonCredentialProvider {}
+
+    impl PartialEq for PythonCredentialProvider {
+        fn eq(&self, other: &Self) -> bool {
+            Arc::ptr_eq(&self.0, &other.0)
+        }
+    }
+
+    impl Hash for PythonCredentialProvider {
+        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+            // # Safety
+            // * Inner is an `Arc`
+            // * Visibility is limited to super
+            // * No code in `mod python_impl` or `super` mutates the Arc inner.
+            state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    #[cfg(feature = "serde")]
+    #[allow(clippy::redundant_pattern_matching)]
+    #[test]
+    fn test_serde() {
+        use super::*;
+
+        assert!(matches!(
+            serde_json::to_string(&Some(PlCredentialProvider::from_func(|| {
+                Box::pin(core::future::ready(PolarsResult::Ok((
+                    ObjectStoreCredential::None,
+                    0,
+                ))))
+            }))),
+            Err(_)
+        ));
+
+        assert!(matches!(
+            serde_json::to_string(&Option::<PlCredentialProvider>::None),
+            Ok(String { .. })
+        ));
+
+        assert!(matches!(
+            serde_json::from_str::<Option<PlCredentialProvider>>(
+                serde_json::to_string(&Option::<PlCredentialProvider>::None)
+                    .unwrap()
+                    .as_str()
+            ),
+            Ok(None)
+        ));
+    }
+}
diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs
index b41f7d45cf21..7ae2d99444a7 100644
--- a/crates/polars-io/src/cloud/mod.rs
+++ b/crates/polars-io/src/cloud/mod.rs
@@ -19,3 +19,6 @@ pub use object_store_setup::*;
 pub use options::*;
 #[cfg(feature = "cloud")]
 pub use polars_object_store::*;
+
+#[cfg(feature = "cloud")]
+pub mod credential_provider;
diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs
index efaab673f634..9549b837f06d 100644
--- a/crates/polars-io/src/cloud/options.rs
+++ b/crates/polars-io/src/cloud/options.rs
@@ -36,6 +36,8 @@ use serde::{Deserialize, Serialize};
 #[cfg(feature = "cloud")]
 use url::Url;
 
+#[cfg(feature = "cloud")]
+use super::credential_provider::PlCredentialProvider;
 #[cfg(feature = "file_cache")]
 use crate::file_cache::get_env_file_cache_ttl;
 #[cfg(feature = "aws")]
@@ -75,6 +77,8 @@ pub struct CloudOptions {
     #[cfg(feature = "file_cache")]
     pub file_cache_ttl: u64,
     pub(crate) config: Option<CloudConfig>,
+    #[cfg(feature = "cloud")]
+    pub(crate) credential_provider: Option<PlCredentialProvider>,
 }
 
 impl Default for CloudOptions {
@@ -84,6 +88,8 @@ impl Default for CloudOptions {
             #[cfg(feature = "file_cache")]
             file_cache_ttl: get_env_file_cache_ttl(),
             config: None,
+            #[cfg(feature = "cloud")]
+            credential_provider: Default::default(),
         }
     }
 }
@@ -248,6 +254,15 @@ impl CloudOptions {
         self
     }
 
+    #[cfg(feature = "cloud")]
+    pub fn with_credential_provider(
+        mut self,
+        credential_provider: Option<PlCredentialProvider>,
+    ) -> Self {
+        self.credential_provider = credential_provider;
+        self
+    }
+
     /// Set the configuration for AWS connections. This is the preferred API from rust.
     #[cfg(feature = "aws")]
     pub fn with_aws<I: IntoIterator<Item = (AmazonS3ConfigKey, impl Into<String>)>>(
@@ -263,6 +278,8 @@ impl CloudOptions {
 
     /// Build the [`object_store::ObjectStore`] implementation for AWS.
#[cfg(feature = "aws")] pub async fn build_aws(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = AmazonS3Builder::from_env().with_url(url); if let Some(options) = &self.config { let CloudConfig::Aws(options) = options else { @@ -346,11 +363,17 @@ impl CloudOptions { }; }; - builder + let builder = builder .with_client_options(get_client_options()) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_aws_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for Azure connections. This is the preferred API from rust. @@ -368,6 +391,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for Azure. #[cfg(feature = "azure")] pub fn build_azure(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = MicrosoftAzureBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Azure(options) = options else { @@ -378,12 +403,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_azure_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } /// Set the configuration for GCP connections. This is the preferred API from rust. @@ -401,6 +432,8 @@ impl CloudOptions { /// Build the [`object_store::ObjectStore`] implementation for GCP. 
#[cfg(feature = "gcp")] pub fn build_gcp(&self, url: &str) -> PolarsResult { + use super::credential_provider::IntoCredentialProvider; + let mut builder = GoogleCloudStorageBuilder::from_env(); if let Some(options) = &self.config { let CloudConfig::Gcp(options) = options else { @@ -411,12 +444,18 @@ impl CloudOptions { } } - builder + let builder = builder .with_client_options(get_client_options()) .with_url(url) - .with_retry(get_retry_config(self.max_retries)) - .build() - .map_err(to_compute_err) + .with_retry(get_retry_config(self.max_retries)); + + let builder = if let Some(v) = self.credential_provider.clone() { + builder.with_credentials(v.into_gcp_provider()) + } else { + builder + }; + + builder.build().map_err(to_compute_err) } #[cfg(feature = "http")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 1b26f6f35737..3d6d0009ea4d 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -51,7 +51,7 @@ version_check = { workspace = true } [features] # debugging utility debugging = [] -python = ["dep:pyo3", "ciborium"] +python = ["dep:pyo3", "ciborium", "polars-utils/python"] serde = [ "dep:serde", "polars-core/serde-lazy", diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 2c0acfeeb13b..483dafcc83f1 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq> { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let udf = python_udf::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(udf)) @@ -399,7 +399,7 @@ impl<'a> Deserialize<'a> for GetOutput { { let buf = Vec::::deserialize(deserializer)?; - if buf.starts_with(python_udf::MAGIC_BYTE_MARK) { + if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) { let get_output = python_udf::PythonGetOutput::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; Ok(SpecialEq::new(get_output)) diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 0f9543a24921..1d8080855e3c 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -1,7 +1,6 @@ use std::io::Cursor; use std::sync::Arc; -use once_cell::sync::Lazy; use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::column::Column; @@ -10,10 +9,6 @@ use polars_core::schema::Schema; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; -#[cfg(feature = "serde")] -use serde::ser::Error; -#[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer, Serialize, Serializer}; use super::expr_dyn_fn::*; use crate::constants::MAP_LIST_NAME; @@ -26,88 +21,10 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option< pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; -#[cfg(feature = "serde")] -pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); -static PYTHON_VERSION_MINOR: Lazy = Lazy::new(get_python_minor_version); -#[derive(Debug)] -pub struct PythonFunction(pub PyObject); - -impl Clone for PythonFunction { - fn clone(&self) -> Self { - Python::with_gil(|py| Self(self.0.clone_ref(py))) - } -} - -impl From for PythonFunction { - fn from(value: PyObject) -> Self { - Self(value) - } 
diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml
index 1b26f6f35737..3d6d0009ea4d 100644
--- a/crates/polars-plan/Cargo.toml
+++ b/crates/polars-plan/Cargo.toml
@@ -51,7 +51,7 @@ version_check = { workspace = true }
 [features]
 # debugging utility
 debugging = []
-python = ["dep:pyo3", "ciborium"]
+python = ["dep:pyo3", "ciborium", "polars-utils/python"]
 serde = [
   "dep:serde",
   "polars-core/serde-lazy",
diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs
index 2c0acfeeb13b..483dafcc83f1 100644
--- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs
+++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs
@@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
     {
         let buf = Vec::<u8>::deserialize(deserializer)?;
 
-        if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
+        if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
             let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
                 .map_err(|e| D::Error::custom(format!("{e}")))?;
             Ok(SpecialEq::new(udf))
@@ -399,7 +399,7 @@ impl<'a> Deserialize<'a> for GetOutput {
     {
         let buf = Vec::<u8>::deserialize(deserializer)?;
 
-        if buf.starts_with(python_udf::MAGIC_BYTE_MARK) {
+        if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
             let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
                 .map_err(|e| D::Error::custom(format!("{e}")))?;
             Ok(SpecialEq::new(get_output))
diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs
index 0f9543a24921..1d8080855e3c 100644
--- a/crates/polars-plan/src/dsl/python_udf.rs
+++ b/crates/polars-plan/src/dsl/python_udf.rs
@@ -1,7 +1,6 @@
 use std::io::Cursor;
 use std::sync::Arc;
 
-use once_cell::sync::Lazy;
 use polars_core::datatypes::{DataType, Field};
 use polars_core::error::*;
 use polars_core::frame::column::Column;
@@ -10,10 +9,6 @@ use polars_core::schema::Schema;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedBytes;
 use pyo3::types::PyBytes;
-#[cfg(feature = "serde")]
-use serde::ser::Error;
-#[cfg(feature = "serde")]
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
 use super::expr_dyn_fn::*;
 use crate::constants::MAP_LIST_NAME;
@@ -26,88 +21,10 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
 pub static mut CALL_DF_UDF_PYTHON: Option<
     fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
 > = None;
-#[cfg(feature = "serde")]
-pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();
-static PYTHON_VERSION_MINOR: Lazy<u8> = Lazy::new(get_python_minor_version);
-
-#[derive(Debug)]
-pub struct PythonFunction(pub PyObject);
-
-impl Clone for PythonFunction {
-    fn clone(&self) -> Self {
-        Python::with_gil(|py| Self(self.0.clone_ref(py)))
-    }
-}
-
-impl From<PyObject> for PythonFunction {
-    fn from(value: PyObject) -> Self {
-        Self(value)
-    }
-}
-
-impl Eq for PythonFunction {}
-
-impl PartialEq for PythonFunction {
-    fn eq(&self, other: &Self) -> bool {
-        Python::with_gil(|py| {
-            let eq = self.0.getattr(py, "__eq__").unwrap();
-            eq.call1(py, (other.0.clone_ref(py),))
-                .unwrap()
-                .extract::<bool>(py)
-                // equality can be not implemented, so default to false
-                .unwrap_or(false)
-        })
-    }
-}
-
-#[cfg(feature = "serde")]
-impl Serialize for PythonFunction {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        Python::with_gil(|py| {
-            let pickle = PyModule::import_bound(py, "cloudpickle")
-                .or_else(|_| PyModule::import_bound(py, "pickle"))
-                .expect("unable to import 'cloudpickle' or 'pickle'")
-                .getattr("dumps")
-                .unwrap();
-
-            let python_function = self.0.clone_ref(py);
-
-            let dumped = pickle
-                .call1((python_function,))
-                .map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?;
-            let dumped = dumped.extract::<PyBackedBytes>().unwrap();
-
-            serializer.serialize_bytes(&dumped)
-        })
-    }
-}
-
-#[cfg(feature = "serde")]
-impl<'a> Deserialize<'a> for PythonFunction {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'a>,
-    {
-        use serde::de::Error;
-        let bytes = Vec::<u8>::deserialize(deserializer)?;
-
-        Python::with_gil(|py| {
-            let pickle = PyModule::import_bound(py, "pickle")
-                .expect("unable to import 'pickle'")
-                .getattr("loads")
-                .unwrap();
-            let arg = (PyBytes::new_bound(py, &bytes),);
-            let python_function = pickle
-                .call1(arg)
-                .map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?;
-
-            Ok(Self(python_function.into()))
-        })
-    }
-}
+pub use polars_utils::python_function::{
+    PythonFunction, PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON_VERSION_MINOR,
+};
 
 pub struct PythonUdfExpression {
     python_function: PyObject,
@@ -134,8 +51,8 @@ impl PythonUdfExpression {
     #[cfg(feature = "serde")]
     pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
         // Handle byte mark
-        debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
-        let buf = &buf[MAGIC_BYTE_MARK.len()..];
+        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
+        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
 
         // Handle pickle metadata
         let use_cloudpickle = buf[0];
@@ -181,7 +98,7 @@ fn from_pyerr(e: PyErr) -> PolarsError {
     PolarsError::ComputeError(format!("error raised in python: {e}").into())
 }
 
-impl DataFrameUdf for PythonFunction {
+impl DataFrameUdf for polars_utils::python_function::PythonFunction {
     fn call_udf(&self, df: DataFrame) -> PolarsResult<DataFrame> {
         let func = unsafe { CALL_DF_UDF_PYTHON.unwrap() };
         func(df, &self.0)
@@ -215,7 +132,7 @@ impl ColumnsUdf for PythonUdfExpression {
     #[cfg(feature = "serde")]
     fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
         // Write byte marks
-        buf.extend_from_slice(MAGIC_BYTE_MARK);
+        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
 
         Python::with_gil(|py| {
             // Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
@@ -273,8 +190,8 @@ impl PythonGetOutput {
     #[cfg(feature = "serde")]
     pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn FunctionOutputField>> {
         // Skip header.
-        debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
-        let buf = &buf[MAGIC_BYTE_MARK.len()..];
+        debug_assert!(buf.starts_with(PYTHON_SERDE_MAGIC_BYTE_MARK));
+        let buf = &buf[PYTHON_SERDE_MAGIC_BYTE_MARK.len()..];
 
         let mut reader = Cursor::new(buf);
         let return_dtype: Option<DataType> =
@@ -302,7 +219,7 @@ impl FunctionOutputField for PythonGetOutput {
 
     #[cfg(feature = "serde")]
     fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
-        buf.extend_from_slice(MAGIC_BYTE_MARK);
+        buf.extend_from_slice(PYTHON_SERDE_MAGIC_BYTE_MARK);
         ciborium::ser::into_writer(&self.return_dtype, &mut *buf).unwrap();
         Ok(())
     }
@@ -342,17 +259,3 @@ impl Expr {
         }
     }
 }
-
-/// Get the minor Python version from the `sys` module.
-fn get_python_minor_version() -> u8 {
-    Python::with_gil(|py| {
-        PyModule::import_bound(py, "sys")
-            .unwrap()
-            .getattr("version_info")
-            .unwrap()
-            .getattr("minor")
-            .unwrap()
-            .extract()
-            .unwrap()
-    })
-}
diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs
index fa28a9f8e5ed..7cf154241966 100644
--- a/crates/polars-python/src/lazyframe/general.rs
+++ b/crates/polars-python/src/lazyframe/general.rs
@@ -257,9 +257,11 @@ impl PyLazyFrame {
 
     #[cfg(feature = "parquet")]
     #[staticmethod]
-    #[pyo3(signature = (source, sources, n_rows, cache, parallel, rechunk, row_index,
-        low_memory, cloud_options, use_statistics, hive_partitioning, schema, hive_schema, try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns)
-    )]
+    #[pyo3(signature = (
+        source, sources, n_rows, cache, parallel, rechunk, row_index, low_memory, cloud_options,
+        credential_provider, use_statistics, hive_partitioning, schema, hive_schema,
+        try_parse_hive_dates, retries, glob, include_file_paths, allow_missing_columns,
+    ))]
     fn new_from_parquet(
         source: Option<PyObject>,
         sources: Wrap<ScanSources>,
         n_rows: Option<usize>,
         cache: bool,
         parallel: Wrap<ParallelStrategy>,
         rechunk: bool,
         row_index: Option<(String, IdxSize)>,
         low_memory: bool,
         cloud_options: Option<Vec<(String, String)>>,
+        credential_provider: Option<PyObject>,
         use_statistics: bool,
         hive_partitioning: Option<bool>,
         schema: Option<Wrap<Schema>>,
         hive_schema: Option<Wrap<Schema>>,
         try_parse_hive_dates: bool,
         retries: usize,
         glob: bool,
         include_file_paths: Option<String>,
         allow_missing_columns: bool,
     ) -> PyResult<Self> {
+        use cloud::credential_provider::PlCredentialProvider;
+
         let parallel = parallel.0;
         let hive_schema = hive_schema.map(|s| Arc::new(s.0));
 
@@ -322,7 +327,13 @@ impl PyLazyFrame {
             let first_path_url = first_path.to_string_lossy();
             let cloud_options =
                 parse_cloud_options(&first_path_url, cloud_options.unwrap_or_default())?;
-            args.cloud_options = Some(cloud_options.with_max_retries(retries));
+            args.cloud_options = Some(
+                cloud_options
+                    .with_max_retries(retries)
+                    .with_credential_provider(
+                        credential_provider.map(PlCredentialProvider::from_python_func_object),
+                    ),
+            );
         }
 
         let lf = LazyFrame::scan_parquet_sources(sources, args).map_err(PyPolarsErr::from)?;
diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml
index 442d319b7753..ef968918d4e8 100644
--- a/crates/polars-utils/Cargo.toml
+++ b/crates/polars-utils/Cargo.toml
@@ -21,6 +21,7 @@ libc = { workspace = true }
 memmap = { workspace = true, optional = true }
 num-traits = { workspace = true }
 once_cell = { workspace = true }
+pyo3 = { workspace = true, optional = true }
 raw-cpuid = { workspace = true }
 rayon = { workspace = true }
 serde = { workspace = true, optional = true }
@@ -39,3 +40,4 @@ bigidx = []
 nightly = []
 ir_serde = ["serde"]
 serde = ["dep:serde", "serde/derive"]
+python = ["pyo3"]
diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs
index eacd517d1254..5c302067e146 100644
--- a/crates/polars-utils/src/lib.rs
+++ b/crates/polars-utils/src/lib.rs
@@ -51,3 +51,6 @@ pub mod partitioned;
 
 pub use index::{IdxSize, NullableIdxSize};
 pub use io::*;
+
+#[cfg(feature = "python")]
+pub mod python_function;
diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs
new file mode 100644
index 000000000000..551d7cda70ae
--- /dev/null
+++ b/crates/polars-utils/src/python_function.rs
@@ -0,0 +1,105 @@
+use once_cell::sync::Lazy;
+use pyo3::prelude::*;
+use pyo3::pybacked::PyBackedBytes;
+use pyo3::types::PyBytes;
+#[cfg(feature = "serde")]
+use serde::ser::Error;
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+#[cfg(feature = "serde")]
+pub const PYTHON_SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();
+pub static PYTHON_VERSION_MINOR: Lazy<u8> = Lazy::new(get_python_minor_version);
+
+#[derive(Debug)]
+pub struct PythonFunction(pub PyObject);
+
+impl Clone for PythonFunction {
+    fn clone(&self) -> Self {
+        Python::with_gil(|py| Self(self.0.clone_ref(py)))
+    }
+}
+
+impl From<PyObject> for PythonFunction {
+    fn from(value: PyObject) -> Self {
+        Self(value)
+    }
+}
+
+impl Eq for PythonFunction {}
+
+impl PartialEq for PythonFunction {
+    fn eq(&self, other: &Self) -> bool {
+        Python::with_gil(|py| {
+            let eq = self.0.getattr(py, "__eq__").unwrap();
+            eq.call1(py, (other.0.clone_ref(py),))
+                .unwrap()
+                .extract::<bool>(py)
+                // equality may not be implemented, so default to false
+                .unwrap_or(false)
+        })
+    }
+}
+
+#[cfg(feature = "serde")]
+impl Serialize for PythonFunction {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        Python::with_gil(|py| {
+            let pickle = PyModule::import_bound(py, "cloudpickle")
+                .or_else(|_| PyModule::import_bound(py, "pickle"))
+                .expect("unable to import 'cloudpickle' or 'pickle'")
+                .getattr("dumps")
+                .unwrap();
+
+            let python_function = self.0.clone_ref(py);
+
+            let dumped = pickle
+                .call1((python_function,))
+                .map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?;
+            let dumped = dumped.extract::<PyBackedBytes>().unwrap();
+
+            serializer.serialize_bytes(&dumped)
+        })
+    }
+}
+
+#[cfg(feature = "serde")]
+impl<'a> Deserialize<'a> for PythonFunction {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'a>,
+    {
+        use serde::de::Error;
+        let bytes = Vec::<u8>::deserialize(deserializer)?;
+
+        Python::with_gil(|py| {
+            let pickle = PyModule::import_bound(py, "pickle")
+                .expect("unable to import 'pickle'")
+                .getattr("loads")
+                .unwrap();
+            let arg = (PyBytes::new_bound(py, &bytes),);
+            let python_function = pickle
+                .call1(arg)
+                .map_err(|s| D::Error::custom(format!("cannot unpickle {s}")))?;
+
+            Ok(Self(python_function.into()))
+        })
+    }
+}
+
+/// Get the minor Python version from the `sys` module.
+fn get_python_minor_version() -> u8 {
+    Python::with_gil(|py| {
+        PyModule::import_bound(py, "sys")
+            .unwrap()
+            .getattr("version_info")
+            .unwrap()
+            .getattr("minor")
+            .unwrap()
+            .extract()
+            .unwrap()
+    })
+}
diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py
index 1670b08aeb2f..b53d67ee8b46 100644
--- a/py-polars/polars/_typing.py
+++ b/py-polars/polars/_typing.py
@@ -4,7 +4,9 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
     Literal,
+    Optional,
     Protocol,
     TypedDict,
     TypeVar,
@@ -294,3 +296,7 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
 
 # LazyFrame engine selection
 EngineType: TypeAlias = Union[Literal["cpu", "gpu"], "GPUEngine"]
+
+CredentialProviderFunction: TypeAlias = Callable[
+    [], tuple[dict[str, Optional[str]], Optional[int]]
+]
diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py
index c27ee99ca94b..80f5e17e9849 100644
--- a/py-polars/polars/io/parquet/functions.py
+++ b/py-polars/polars/io/parquet/functions.py
@@ -28,7 +28,7 @@
 if TYPE_CHECKING:
     from polars import DataFrame, DataType, LazyFrame
-    from polars._typing import ParallelStrategy, SchemaDict
+    from polars._typing import CredentialProviderFunction, ParallelStrategy, SchemaDict
 
 
 @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
@@ -338,6 +338,7 @@ def scan_parquet(
     low_memory: bool = False,
     cache: bool = True,
     storage_options: dict[str, Any] | None = None,
+    credential_provider: CredentialProviderFunction | None = None,
     retries: int = 2,
     include_file_paths: str | None = None,
     allow_missing_columns: bool = False,
@@ -426,6 +427,14 @@ def scan_parquet(
         If `storage_options` is not provided, Polars will try to infer the
         information from environment variables.
+    credential_provider
+        Provide a function that can be called to provide cloud storage
+        credentials. The function is expected to return a dictionary of
+        credential keys along with an optional credential expiry time.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
     retries
         Number of retries if accessing a cloud instance fails.
     include_file_paths
@@ -467,6 +476,10 @@ def scan_parquet(
         msg = "The `hive_schema` parameter of `scan_parquet` is considered unstable."
         issue_unstable_warning(msg)
 
+    if credential_provider is not None:
+        msg = "The `credential_provider` parameter of `scan_parquet` is considered unstable."
+        issue_unstable_warning(msg)
+
     if isinstance(source, (str, Path)):
         source = normalize_filepath(source, check_not_directory=False)
     elif is_path_or_str_sequence(source):
@@ -483,6 +496,7 @@ def scan_parquet(
         row_index_name=row_index_name,
         row_index_offset=row_index_offset,
         storage_options=storage_options,
+        credential_provider=credential_provider,
         low_memory=low_memory,
         use_statistics=use_statistics,
         hive_partitioning=hive_partitioning,
@@ -506,6 +520,7 @@ def _scan_parquet_impl(
     row_index_name: str | None = None,
     row_index_offset: int = 0,
     storage_options: dict[str, object] | None = None,
+    credential_provider: CredentialProviderFunction | None = None,
    low_memory: bool = False,
     use_statistics: bool = True,
     hive_partitioning: bool | None = None,
@@ -539,6 +554,7 @@ def _scan_parquet_impl(
         parse_row_index_args(row_index_name, row_index_offset),
         low_memory,
         cloud_options=storage_options,
+        credential_provider=credential_provider,
         use_statistics=use_statistics,
         hive_partitioning=hive_partitioning,
         schema=schema,
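
To make the new Python parameter concrete, a usage sketch (not part of this diff; the bucket path and returned values are illustrative, with the key names taken from the AWS handling in credential_provider.rs above):

    import polars as pl

    def get_credentials() -> tuple[dict[str, str | None], int | None]:
        # Fetch short-lived credentials from a secret store or STS here.
        # Returning `None` as the expiry means the result is never refetched.
        return {
            "aws_access_key_id": "...",
            "aws_secret_access_key": "...",
            "aws_session_token": None,
        }, None

    lf = pl.scan_parquet(
        "s3://example-bucket/data/*.parquet",
        credential_provider=get_credentials,
    )
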