From 49edd1af907db5c716d7569ba602a58405055528 Mon Sep 17 00:00:00 2001 From: Krzysztof Piotrowski Date: Mon, 9 Sep 2024 19:00:18 +0000 Subject: [PATCH] fix: handle error when jwt token could not be retrived Signed-off-by: Krzysztof Piotrowski --- .../extensions/c8y_auth_proxy/src/server.rs | 44 +++++++++++++------ .../extensions/c8y_auth_proxy/src/tokens.rs | 14 +++--- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/crates/extensions/c8y_auth_proxy/src/server.rs b/crates/extensions/c8y_auth_proxy/src/server.rs index 24d907bbf7e..c3c54468efb 100644 --- a/crates/extensions/c8y_auth_proxy/src/server.rs +++ b/crates/extensions/c8y_auth_proxy/src/server.rs @@ -256,21 +256,30 @@ async fn proxy_ws( use axum::extract::ws::CloseFrame; use tungstenite::error::Error; let uri = format!("{}/{path}", host.ws); - let mut token = retrieve_token.not_matching(None).await; - let c8y = match connect_to_websocket(&token, &headers, &uri, &host).await { - Ok(c8y) => Ok(c8y), - Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => { - token = retrieve_token.not_matching(Some(&token)).await; - match connect_to_websocket(&token, &headers, &uri, &host).await { + + let c8y = { + match retrieve_token.not_matching(None).await { + Ok(token) => match connect_to_websocket(&token, &headers, &uri, &host).await { Ok(c8y) => Ok(c8y), - Err(e) => { - Err(anyhow::Error::from(e).context("Failed to connect to proxied websocket")) + Err(Error::Http(res)) if res.status() == StatusCode::UNAUTHORIZED => { + match retrieve_token.not_matching(Some(&token)).await { + Ok(token) => { + match connect_to_websocket(&token, &headers, &uri, &host).await { + Ok(c8y) => Ok(c8y), + Err(e) => Err(anyhow::Error::from(e) + .context("Failed to connect to proxied websocket")), + } + } + Err(e) => Err(e.context("Failed to retrieve JWT token")), + } } - } + Err(e) => Err(anyhow::Error::from(e)), + }, + Err(e) => Err(e.context("Failed to retrieve JWT token")), } - Err(e) => Err(anyhow::Error::from(e)), } .context("Error connecting to proxied websocket"); + let c8y = match c8y { Err(e) => { let _ = ws @@ -413,7 +422,10 @@ async fn respond_to( destination += query; } - let mut token = retrieve_token.not_matching(None).await; + let mut token = retrieve_token + .not_matching(None) + .await + .with_context(|| "failed to retrieve JWT token")?; if let Some(ws) = ws { let path = path.to_owned(); @@ -429,7 +441,10 @@ async fn respond_to( .await .with_context(|| format!("making HEAD request to {destination}"))?; if response.status() == StatusCode::UNAUTHORIZED { - token = retrieve_token.not_matching(Some(&token)).await; + token = retrieve_token + .not_matching(Some(&token)) + .await + .with_context(|| "failed to retrieve JWT token")?; } } @@ -448,7 +463,10 @@ async fn respond_to( .with_context(|| format!("making proxied request to {destination}"))?; if res.status() == StatusCode::UNAUTHORIZED { - token = retrieve_token.not_matching(Some(&token)).await; + token = retrieve_token + .not_matching(Some(&token)) + .await + .with_context(|| "failed to retrieve JWT token")?; if let Some(body) = body_clone { res = send_request(Body::from(body), &token) .await diff --git a/crates/extensions/c8y_auth_proxy/src/tokens.rs b/crates/extensions/c8y_auth_proxy/src/tokens.rs index a402ff610a3..ab9e2d02d27 100644 --- a/crates/extensions/c8y_auth_proxy/src/tokens.rs +++ b/crates/extensions/c8y_auth_proxy/src/tokens.rs @@ -10,7 +10,7 @@ impl SharedTokenManager { /// Returns a JWT that doesn't match the provided JWT /// /// This prevents needless token refreshes if multiple requests are made in parallel - pub async fn not_matching(&self, input: Option<&Arc>) -> Arc { + pub async fn not_matching(&self, input: Option<&Arc>) -> Result, anyhow::Error> { self.0.lock().await.not_matching(input).await } } @@ -31,17 +31,17 @@ impl TokenManager { } impl TokenManager { - async fn not_matching(&mut self, input: Option<&Arc>) -> Arc { + async fn not_matching(&mut self, input: Option<&Arc>) -> Result, anyhow::Error> { match (self.cached.as_mut(), input) { - (Some(token), None) => token.clone(), + (Some(token), None) => Ok(token.clone()), // The token should have arisen from this TokenManager, so pointer equality is sufficient - (Some(token), Some(no_match)) if !Arc::ptr_eq(token, no_match) => token.clone(), + (Some(token), Some(no_match)) if !Arc::ptr_eq(token, no_match) => Ok(token.clone()), _ => self.refresh().await, } } - async fn refresh(&mut self) -> Arc { - self.cached = Some(self.recv.await_response(()).await.unwrap().unwrap().into()); - self.cached.as_ref().unwrap().clone() + async fn refresh(&mut self) -> Result, anyhow::Error> { + self.cached = Some(self.recv.await_response(()).await??.into()); + Ok(self.cached.as_ref().unwrap().clone()) } }