From 60d590f1c4a739405b0b377c043dd3f566202621 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Tue, 7 Jan 2025 22:15:04 -0800 Subject: [PATCH] re-introduce custom token cache --- Cargo.lock | 2 +- Cargo.toml | 3 +- crates/arroyo-storage/src/aws.rs | 52 +++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a6a08c3db..2f3bc2a77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6312,7 +6312,7 @@ dependencies = [ [[package]] name = "object_store" version = "0.11.1" -source = "git+http://github.com/ArroyoSystems/arrow-rs?branch=object_store_0.11.1%2Farroyo#4cfe48061503161e43cd3cd7960e74ce789bd3b9" +source = "git+http://github.com/ArroyoSystems/arrow-rs?branch=public_token_cache#8b0175855610c7e895546afc8df966a4aeabee88" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/Cargo.toml b/Cargo.toml index 1e6728393..bbe333bf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,7 +90,8 @@ datafusion-functions-window = {git = 'https://github.com/ArroyoSystems/arrow-dat datafusion-functions-json = {git = 'https://github.com/ArroyoSystems/datafusion-functions-json', branch = 'datafusion_43'} -object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' } +# object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'object_store_0.11.1/arroyo' } +object_store = { git = 'http://github.com/ArroyoSystems/arrow-rs', branch = 'public_token_cache' } cornucopia_async = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" } cornucopia = { git = "https://github.com/ArroyoSystems/cornucopia", branch = "sqlite" } diff --git a/crates/arroyo-storage/src/aws.rs b/crates/arroyo-storage/src/aws.rs index 936bbd545..c087df9e4 100644 --- a/crates/arroyo-storage/src/aws.rs +++ b/crates/arroyo-storage/src/aws.rs @@ -1,14 +1,15 @@ use crate::StorageError; -use aws_config::identity::IdentityCache; use aws_config::timeout::TimeoutConfig; use aws_config::{BehaviorVersion, SdkConfig}; use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; -use object_store::{aws::AwsCredential, CredentialProvider}; +use object_store::{aws::AwsCredential, CredentialProvider, TemporaryToken, TokenCache}; +use std::error::Error; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::sync::OnceCell; pub struct ArroyoCredentialProvider { + cache: TokenCache>, provider: SharedCredentialsProvider, } @@ -31,11 +32,6 @@ async fn get_config<'a>() -> &'a SdkConfig { .operation_attempt_timeout(Duration::from_secs(5)) .build(), ) - .identity_cache( - IdentityCache::lazy() - .buffer_time(Duration::from_secs(60 * 5)) - .build(), - ) .load() .await, ) @@ -57,6 +53,7 @@ impl ArroyoCredentialProvider { .clone(); Ok(Self { + cache: Default::default(), provider: credentials, }) } @@ -66,21 +63,40 @@ impl ArroyoCredentialProvider { } } +async fn get_token( + provider: &SharedCredentialsProvider, +) -> Result>, Box> { + let creds = provider + .provide_credentials() + .await + .map_err(|e| object_store::Error::Generic { + store: "S3", + source: Box::new(e), + })?; + let expiry = creds + .expiry() + .map(|exp| Instant::now() + exp.elapsed().unwrap_or_default()); + Ok(TemporaryToken { + token: Arc::new(AwsCredential { + key_id: creds.access_key_id().to_string(), + secret_key: creds.secret_access_key().to_string(), + token: creds.session_token().map(ToString::to_string), + }), + expiry, + }) +} + #[async_trait::async_trait] impl CredentialProvider for ArroyoCredentialProvider { type Credential = AwsCredential; async fn get_credential(&self) -> object_store::Result> { - let creds = self.provider.provide_credentials().await.map_err(|e| { - object_store::Error::Generic { + self.cache + .get_or_insert_with(|| get_token(&self.provider)) + .await + .map_err(|e| object_store::Error::Generic { store: "S3", - source: Box::new(e), - } - })?; - Ok(Arc::new(AwsCredential { - key_id: creds.access_key_id().to_string(), - secret_key: creds.secret_access_key().to_string(), - token: creds.session_token().map(ToString::to_string), - })) + source: e, + }) } }