Skip to content

Commit

Permalink
fix(catalog/rest): Ensure token been reused correctly (#801)
Browse files Browse the repository at this point in the history
* fix(catalog/rest): Ensure token been reused correctly

Signed-off-by: Xuanwo <[email protected]>

* Fix oauth test

Signed-off-by: Xuanwo <[email protected]>

* Fix tests

Signed-off-by: Xuanwo <[email protected]>

---------

Signed-off-by: Xuanwo <[email protected]>
  • Loading branch information
Xuanwo authored Dec 14, 2024
1 parent 748d37c commit 54926a2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 7 deletions.
17 changes: 10 additions & 7 deletions crates/catalog/rest/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ impl RestCatalog {
async fn context(&self) -> Result<&RestContext> {
self.ctx
.get_or_try_init(|| async {
let catalog_config = RestCatalog::load_config(&self.user_config).await?;
let client = HttpClient::new(&self.user_config)?;
let catalog_config = RestCatalog::load_config(&client, &self.user_config).await?;
let config = self.user_config.clone().merge_with_config(catalog_config);
let client = HttpClient::new(&config)?;
let client = client.update_with(&config)?;

Ok(RestContext { config, client })
})
Expand All @@ -268,9 +269,10 @@ impl RestCatalog {
/// Load the runtime config from the server by user_config.
///
/// It's required for a rest catalog to update it's config after creation.
async fn load_config(user_config: &RestCatalogConfig) -> Result<CatalogConfig> {
let client = HttpClient::new(user_config)?;

async fn load_config(
client: &HttpClient,
user_config: &RestCatalogConfig,
) -> Result<CatalogConfig> {
let mut request = client.request(Method::GET, user_config.config_endpoint());

if let Some(warehouse_location) = &user_config.warehouse {
Expand All @@ -280,6 +282,7 @@ impl RestCatalog {
let config = client
.query::<CatalogConfig, ErrorResponse, OK>(request.build()?)
.await?;

Ok(config)
}

Expand Down Expand Up @@ -777,7 +780,7 @@ mod tests {
"expires_in": 86400
}"#,
)
.expect(2)
.expect(1)
.create_async()
.await
}
Expand Down Expand Up @@ -831,7 +834,7 @@ mod tests {
"expires_in": 86400
}"#,
)
.expect(2)
.expect(1)
.create_async()
.await;

Expand Down
27 changes: 27 additions & 0 deletions crates/catalog/rest/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl Debug for HttpClient {
}

impl HttpClient {
/// Create a new http client.
pub fn new(cfg: &RestCatalogConfig) -> Result<Self> {
Ok(HttpClient {
client: Client::new(),
Expand All @@ -66,6 +67,32 @@ impl HttpClient {
})
}

/// Update the http client with new configuration.
///
/// If cfg carries new value, we will use cfg instead.
/// Otherwise, we will keep the old value.
pub fn update_with(self, cfg: &RestCatalogConfig) -> Result<Self> {
Ok(HttpClient {
client: self.client,

token: Mutex::new(
cfg.token()
.or_else(|| self.token.into_inner().ok().flatten()),
),
token_endpoint: (!cfg.get_token_endpoint().is_empty())
.then(|| cfg.get_token_endpoint())
.unwrap_or(self.token_endpoint),
credential: cfg.credential().or(self.credential),
extra_headers: (!cfg.extra_headers()?.is_empty())
.then(|| cfg.extra_headers())
.transpose()?
.unwrap_or(self.extra_headers),
extra_oauth_params: (!cfg.extra_oauth_params().is_empty())
.then(|| cfg.extra_oauth_params())
.unwrap_or(self.extra_oauth_params),
})
}

/// This API is testing only to assert the token.
#[cfg(test)]
pub(crate) async fn token(&self) -> Option<String> {
Expand Down

0 comments on commit 54926a2

Please sign in to comment.