From e92e3a6668e7e9862715b24ef839e9cb33b88f00 Mon Sep 17 00:00:00 2001 From: Shing Him Ng Date: Fri, 3 Jan 2025 09:03:31 -0600 Subject: [PATCH] Introduce DbPoolError to store Redis and timeout errors --- payjoin-directory/src/db.rs | 22 ++++++++++++++++++---- payjoin-directory/src/error.rs | 25 +++++++++++++++++++++++++ payjoin-directory/src/lib.rs | 23 +++++++++++++---------- 3 files changed, 56 insertions(+), 14 deletions(-) create mode 100644 payjoin-directory/src/error.rs diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 6165abf9..d896c938 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -3,6 +3,7 @@ use std::time::Duration; use futures::StreamExt; use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult}; use tracing::debug; +use crate::error::DbPoolError; const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; @@ -19,11 +20,15 @@ impl DbPool { Ok(Self { client, timeout }) } + /// Peek using [`DEFAULT_COLUMN`] as the channel type. pub async fn push_default(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { self.push(subdirectory_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, subdirectory_id: &str) -> Option>> { + pub async fn peek_default( + &self, + subdirectory_id: &str, + ) -> Result, DbPoolError> { self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await } @@ -31,7 +36,8 @@ impl DbPool { self.push(subdirectory_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, subdirectory_id: &str) -> Option>> { + /// Peek using [`PJ_V1_COLUMN`] as the channel type. + pub async fn peek_v1(&self, subdirectory_id: &str) -> Result, DbPoolError> { self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await } @@ -52,8 +58,16 @@ impl DbPool { &self, subdirectory_id: &str, channel_type: &str, - ) -> Option>> { - tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok() + ) -> Result, DbPoolError> { + match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await { + Ok(redis_result) => { + match redis_result { + Ok(result) => { Ok(result) } + Err(redis_err) => { Err(DbPoolError::Redis(redis_err))} + } + } + Err(elapsed) => {Err(DbPoolError::Timeout(elapsed))} + } } async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult> { diff --git a/payjoin-directory/src/error.rs b/payjoin-directory/src/error.rs new file mode 100644 index 00000000..9532500a --- /dev/null +++ b/payjoin-directory/src/error.rs @@ -0,0 +1,25 @@ +#[derive(Debug)] +pub enum DbPoolError { + Redis(redis::RedisError), + Timeout(tokio::time::error::Elapsed) +} + +impl std::fmt::Display for DbPoolError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use DbPoolError::*; + + match &self { + Redis(error) => write!(f, "Redis error: {}", error), + Timeout(timeout) => write!(f, "Timeout: {}", timeout) + } + } +} + +impl std::error::Error for DbPoolError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + DbPoolError::Redis(e) => Some(e), + DbPoolError::Timeout(e) => Some(e), + } + } +} diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 9a1c651c..cf91be48 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -34,7 +34,10 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message" const ID_LENGTH: usize = 13; mod db; +mod error; + use crate::db::DbPool; +use crate::error::DbPoolError; #[cfg(feature = "_danger-local-https")] type BoxError = Box; @@ -341,11 +344,11 @@ async fn post_fallback_v1( .await .map_err(|e| HandlerError::BadRequest(e.into()))?; match pool.peek_v1(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(none_response), + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + DbPoolError::Redis(_) => Err(HandlerError::BadRequest(e.into())), + DbPoolError::Timeout(_) => Ok(none_response) + } } } @@ -409,11 +412,11 @@ async fn get_subdir( trace!("get_subdir"); let id = check_id_length(id)?; match pool.peek_default(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?), + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + DbPoolError::Redis(_) => Err(HandlerError::BadRequest(e.into())), + DbPoolError::Timeout(_) => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?), + } } }