diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index fc96ce4fc3ef..283d7ff9d703 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -40,7 +40,7 @@ use std::borrow::Cow; use std::process::Command; use std::str; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime}; use url::Url; static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2021-08-06"); @@ -293,13 +293,16 @@ fn lexy_sort<'a>( values } +/// #[derive(Deserialize, Debug)] -struct TokenResponse { +struct OAuthTokenResponse { access_token: String, expires_in: u64, } /// Encapsulates the logic to perform an OAuth token challenge +/// +/// #[derive(Debug)] pub struct ClientSecretOAuthProvider { token_url: String, @@ -340,7 +343,7 @@ impl TokenProvider for ClientSecretOAuthProvider { client: &Client, retry: &RetryConfig, ) -> crate::Result>> { - let response: TokenResponse = client + let response: OAuthTokenResponse = client .request(Method::POST, &self.token_url) .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) .form(&[ @@ -363,21 +366,27 @@ impl TokenProvider for ClientSecretOAuthProvider { } } -fn expires_in_string<'de, D>(deserializer: D) -> std::result::Result +fn expires_on_string<'de, D>(deserializer: D) -> std::result::Result where D: serde::de::Deserializer<'de>, { let v = String::deserialize(deserializer)?; - v.parse::().map_err(serde::de::Error::custom) + let v = v.parse::().map_err(serde::de::Error::custom)?; + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(serde::de::Error::custom)?; + + Ok(Instant::now() + Duration::from_secs(v.saturating_sub(now.as_secs()))) } -// NOTE: expires_on is a String version of unix epoch time, not an integer. -// +/// NOTE: expires_on is a String version of unix epoch time, not an integer. +/// +/// #[derive(Debug, Clone, Deserialize)] -struct MsiTokenResponse { +struct ImdsTokenResponse { pub access_token: String, - #[serde(deserialize_with = "expires_in_string")] - pub expires_in: u64, + #[serde(deserialize_with = "expires_on_string")] + pub expires_on: Instant, } /// Attempts authentication using a managed identity that has been assigned to the deployment environment. @@ -450,7 +459,7 @@ impl TokenProvider for ImdsManagedIdentityProvider { builder = builder.header("x-identity-header", val); }; - let response: MsiTokenResponse = builder + let response: ImdsTokenResponse = builder .send_retry(retry) .await .context(TokenRequestSnafu)? @@ -460,12 +469,12 @@ impl TokenProvider for ImdsManagedIdentityProvider { Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(response.access_token)), - expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)), + expiry: Some(response.expires_on), }) } } -/// Credential for using workload identity dfederation +/// Credential for using workload identity federation /// /// #[derive(Debug)] @@ -512,7 +521,7 @@ impl TokenProvider for WorkloadIdentityOAuthProvider { .map_err(|_| Error::FederatedTokenFile)?; // https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow#third-case-access-token-request-with-a-federated-credential - let response: TokenResponse = client + let response: OAuthTokenResponse = client .request(Method::POST, &self.token_url) .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) .form(&[