From 31dc4a2dfb6bbe5f0539b60b7b752e05c18c637f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= <76261501+zecakeh@users.noreply.github.com> Date: Sun, 7 Apr 2024 15:52:43 +0200 Subject: [PATCH] client-api: Add support for the Retry-After header According to MSC4041 / Matrix 1.10 Co-authored-by: Jonas Platte --- crates/ruma-client-api/CHANGELOG.md | 3 + crates/ruma-client-api/Cargo.toml | 3 + crates/ruma-client-api/src/error.rs | 367 ++++++++++++++++-- .../ruma-client-api/src/error/kind_serde.rs | 9 +- crates/ruma-common/src/api/error.rs | 7 + 5 files changed, 353 insertions(+), 36 deletions(-) diff --git a/crates/ruma-client-api/CHANGELOG.md b/crates/ruma-client-api/CHANGELOG.md index 317e9b85cc..b3bad44e1a 100644 --- a/crates/ruma-client-api/CHANGELOG.md +++ b/crates/ruma-client-api/CHANGELOG.md @@ -16,6 +16,9 @@ Breaking changes: - `Error` is now non-exhaustive. - `ErrorKind::Forbidden` is now a non-exhaustive struct variant that can be constructed with `ErrorKind::forbidden()`. +- The `retry_after_ms` field of `ErrorKind::LimitExceeded` was renamed to + `retry_after` and is now an `Option`, to add support for the + Retry-After header, according to MSC4041 / Matrix 1.10 Improvements: diff --git a/crates/ruma-client-api/Cargo.toml b/crates/ruma-client-api/Cargo.toml index 00d30782aa..157c17dd57 100644 --- a/crates/ruma-client-api/Cargo.toml +++ b/crates/ruma-client-api/Cargo.toml @@ -52,6 +52,7 @@ unstable-msc3983 = [] as_variant = { workspace = true } assign = { workspace = true } bytes = "1.0.1" +date_header = "1.0.5" http = { workspace = true } js_int = { workspace = true, features = ["serde"] } js_option = "0.1.1" @@ -61,6 +62,8 @@ ruma-events = { workspace = true } serde = { workspace = true } serde_html_form = { workspace = true } serde_json = { workspace = true } +thiserror = { workspace = true } +web-time = { workspace = true } [dev-dependencies] assert_matches2 = { workspace = true } diff --git a/crates/ruma-client-api/src/error.rs b/crates/ruma-client-api/src/error.rs index 99b942606c..8262da627f 100644 --- a/crates/ruma-client-api/src/error.rs +++ b/crates/ruma-client-api/src/error.rs @@ -1,6 +1,6 @@ //! Errors that can be sent from the homeserver. -use std::{collections::BTreeMap, fmt, sync::Arc, time::Duration}; +use std::{collections::BTreeMap, fmt, str::FromStr, sync::Arc}; use as_variant::as_variant; use bytes::{BufMut, Bytes}; @@ -13,6 +13,8 @@ use ruma_common::{ }; use serde::{Deserialize, Serialize}; use serde_json::{from_slice as from_json_slice, Value as JsonValue}; +use thiserror::Error; +use web_time::{Duration, SystemTime, UNIX_EPOCH}; use crate::PrivOwnedStr; @@ -59,8 +61,8 @@ pub enum ErrorKind { /// M_LIMIT_EXCEEDED LimitExceeded { - /// How long a client should wait in milliseconds before they can try again. - retry_after_ms: Option, + /// How long a client should wait before they can try again. + retry_after: Option, }, /// M_UNKNOWN @@ -350,19 +352,29 @@ impl EndpointError for Error { let body_bytes = &response.body().as_ref(); let error_body: ErrorBody = match from_json_slice(body_bytes) { - Ok(StandardErrorBody { kind, message }) => { - #[cfg(feature = "unstable-msc2967")] - let kind = if let ErrorKind::Forbidden { .. } = kind { - let authenticate = response - .headers() - .get(http::header::WWW_AUTHENTICATE) - .and_then(|val| val.to_str().ok()) - .and_then(AuthenticateError::from_str); - - ErrorKind::Forbidden { authenticate } - } else { - kind - }; + Ok(StandardErrorBody { mut kind, message }) => { + let headers = response.headers(); + + match &mut kind { + #[cfg(feature = "unstable-msc2967")] + ErrorKind::Forbidden { authenticate } => { + *authenticate = headers + .get(http::header::WWW_AUTHENTICATE) + .and_then(|val| val.to_str().ok()) + .and_then(AuthenticateError::from_str); + } + ErrorKind::LimitExceeded { retry_after } => { + // The Retry-After header takes precedence over the retry_after_ms field in + // the body. + if let Some(retry_after_header) = headers + .get(http::header::RETRY_AFTER) + .and_then(RetryAfter::from_header_value) + { + *retry_after = Some(retry_after_header); + } + } + _ => {} + } ErrorBody::Standard { kind, message } } @@ -406,20 +418,24 @@ impl OutgoingResponse for Error { fn try_into_http_response( self, ) -> Result, IntoHttpError> { - let builder = http::Response::builder() + let mut builder = http::Response::builder() .header(http::header::CONTENT_TYPE, "application/json") .status(self.status_code); - #[cfg(feature = "unstable-msc2967")] - let builder = if let ErrorBody::Standard { - kind: ErrorKind::Forbidden { authenticate: Some(auth_error) }, - .. - } = &self.body - { - builder.header(http::header::WWW_AUTHENTICATE, auth_error) - } else { - builder - }; + #[allow(clippy::collapsible_match)] + if let ErrorBody::Standard { kind, .. } = &self.body { + match kind { + #[cfg(feature = "unstable-msc2967")] + ErrorKind::Forbidden { authenticate: Some(auth_error) } => { + builder = builder.header(http::header::WWW_AUTHENTICATE, auth_error); + } + ErrorKind::LimitExceeded { retry_after: Some(retry_after) } => { + let header_value = http::HeaderValue::try_from(retry_after)?; + builder = builder.header(http::header::RETRY_AFTER, header_value); + } + _ => {} + } + } builder .body(match self.body { @@ -532,6 +548,72 @@ impl TryFrom<&AuthenticateError> for http::HeaderValue { } } +/// How long a client should wait before it tries again. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(clippy::exhaustive_enums)] +pub enum RetryAfter { + /// The client should wait for the given duration. + /// + /// This variant should be preferred for backwards compatibility, as it will also populate the + /// `retry_after_ms` field in the body of the response. + Delay(Duration), + /// The client should wait for the given date and time. + DateTime(SystemTime), +} + +impl RetryAfter { + fn from_header_value(value: &http::HeaderValue) -> Option { + let bytes = value.as_bytes(); + + if bytes.iter().all(|b| b.is_ascii_digit()) { + // It should be a duration. + Some(Self::Delay(Duration::from_secs(u64::from_str(value.to_str().ok()?).ok()?))) + } else { + // It should be a date. + let ts = date_header::parse(bytes).ok()?; + Some(Self::DateTime(UNIX_EPOCH.checked_add(Duration::from_secs(ts))?)) + } + } +} + +impl TryFrom<&RetryAfter> for http::HeaderValue { + type Error = RetryAfterInvalidDateTime; + + fn try_from(value: &RetryAfter) -> Result { + match value { + RetryAfter::Delay(duration) => Ok(duration.as_secs().into()), + RetryAfter::DateTime(time) => { + let mut buffer = [0; 29]; + let duration = + time.duration_since(UNIX_EPOCH).map_err(|_| RetryAfterInvalidDateTime)?; + date_header::format(duration.as_secs(), &mut buffer) + .map_err(|_| RetryAfterInvalidDateTime)?; + let value = http::HeaderValue::from_bytes(&buffer) + .expect("date_header should produce a valid header value"); + + Ok(value) + } + } + } +} + +/// An error when converting a [`RetryAfter`] to a [`http::HeaderValue`]. +/// +/// Happens when the `DateTime` is too far in the past (before the Unix epoch) or the +/// future (after the year 9999). +#[derive(Debug, Error)] +#[allow(clippy::exhaustive_structs)] +#[error( + "Retry-After header serialization failed: the datetime is too far in the past or the future" +)] +pub struct RetryAfterInvalidDateTime; + +impl From for IntoHttpError { + fn from(_value: RetryAfterInvalidDateTime) -> Self { + IntoHttpError::RetryAfterInvalidDatetime + } +} + /// Extension trait for `FromHttpResponseError`. pub trait FromHttpResponseErrorExt { /// If `self` is a server error in the `errcode` + `error` format expected @@ -548,9 +630,13 @@ impl FromHttpResponseErrorExt for FromHttpResponseError { #[cfg(test)] mod tests { use assert_matches2::assert_matches; - use serde_json::{from_value as from_json_value, json}; + use ruma_common::api::{EndpointError, OutgoingResponse}; + use serde_json::{ + from_slice as from_json_slice, from_value as from_json_value, json, Value as JsonValue, + }; + use web_time::{Duration, UNIX_EPOCH}; - use super::{ErrorKind, StandardErrorBody}; + use super::{Error, ErrorBody, ErrorKind, RetryAfter, StandardErrorBody}; #[test] fn deserialize_forbidden() { @@ -615,9 +701,7 @@ mod tests { #[cfg(feature = "unstable-msc2967")] #[test] fn deserialize_insufficient_scope() { - use ruma_common::api::EndpointError; - - use super::{AuthenticateError, Error, ErrorBody}; + use super::AuthenticateError; let response = http::Response::builder() .header( @@ -642,4 +726,223 @@ mod tests { assert_matches!(authenticate, Some(AuthenticateError::InsufficientScope { scope })); assert_eq!(scope, "something_privileged"); } + + #[test] + fn deserialize_limit_exceeded_no_retry_after() { + let response = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .body( + serde_json::to_string(&json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + })) + .unwrap(), + ) + .unwrap(); + let error = Error::from_http_response(response); + + assert_eq!(error.status_code, http::StatusCode::TOO_MANY_REQUESTS); + assert_matches!( + error.body, + ErrorBody::Standard { kind: ErrorKind::LimitExceeded { retry_after: None }, message } + ); + assert_eq!(message, "Too many requests"); + } + + #[test] + fn deserialize_limit_exceeded_retry_after_body() { + let response = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .body( + serde_json::to_string(&json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + "retry_after_ms": 2000, + })) + .unwrap(), + ) + .unwrap(); + let error = Error::from_http_response(response); + + assert_eq!(error.status_code, http::StatusCode::TOO_MANY_REQUESTS); + assert_matches!( + error.body, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { retry_after: Some(retry_after) }, + message + } + ); + assert_matches!(retry_after, RetryAfter::Delay(delay)); + assert_eq!(delay.as_millis(), 2000); + assert_eq!(message, "Too many requests"); + } + + #[test] + fn deserialize_limit_exceeded_retry_after_header_delay() { + let response = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "2") + .body( + serde_json::to_string(&json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + })) + .unwrap(), + ) + .unwrap(); + let error = Error::from_http_response(response); + + assert_eq!(error.status_code, http::StatusCode::TOO_MANY_REQUESTS); + assert_matches!( + error.body, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { retry_after: Some(retry_after) }, + message + } + ); + assert_matches!(retry_after, RetryAfter::Delay(delay)); + assert_eq!(delay.as_millis(), 2000); + assert_eq!(message, "Too many requests"); + } + + #[test] + fn deserialize_limit_exceeded_retry_after_header_datetime() { + let response = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "Fri, 15 May 2015 15:34:21 GMT") + .body( + serde_json::to_string(&json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + })) + .unwrap(), + ) + .unwrap(); + let error = Error::from_http_response(response); + + assert_eq!(error.status_code, http::StatusCode::TOO_MANY_REQUESTS); + assert_matches!( + error.body, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { retry_after: Some(retry_after) }, + message + } + ); + assert_matches!(retry_after, RetryAfter::DateTime(time)); + assert_eq!(time.duration_since(UNIX_EPOCH).unwrap().as_secs(), 1_431_704_061); + assert_eq!(message, "Too many requests"); + } + + #[test] + fn deserialize_limit_exceeded_retry_after_header_over_body() { + let response = http::Response::builder() + .status(http::StatusCode::TOO_MANY_REQUESTS) + .header(http::header::RETRY_AFTER, "2") + .body( + serde_json::to_string(&json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + "retry_after_ms": 3000, + })) + .unwrap(), + ) + .unwrap(); + let error = Error::from_http_response(response); + + assert_eq!(error.status_code, http::StatusCode::TOO_MANY_REQUESTS); + assert_matches!( + error.body, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { retry_after: Some(retry_after) }, + message + } + ); + assert_matches!(retry_after, RetryAfter::Delay(delay)); + assert_eq!(delay.as_millis(), 2000); + assert_eq!(message, "Too many requests"); + } + + #[test] + fn serialize_limit_exceeded_retry_after_none() { + let error = Error::new( + http::StatusCode::TOO_MANY_REQUESTS, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { retry_after: None }, + message: "Too many requests".to_owned(), + }, + ); + + let response = error.try_into_http_response::>().unwrap(); + + assert_eq!(response.status(), http::StatusCode::TOO_MANY_REQUESTS); + assert_eq!(response.headers().get(http::header::RETRY_AFTER), None); + + let json_body: JsonValue = from_json_slice(response.body()).unwrap(); + assert_eq!( + json_body, + json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + }) + ); + } + + #[test] + fn serialize_limit_exceeded_retry_after_delay() { + let error = Error::new( + http::StatusCode::TOO_MANY_REQUESTS, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { + retry_after: Some(RetryAfter::Delay(Duration::from_secs(3))), + }, + message: "Too many requests".to_owned(), + }, + ); + + let response = error.try_into_http_response::>().unwrap(); + + assert_eq!(response.status(), http::StatusCode::TOO_MANY_REQUESTS); + let retry_after_header = response.headers().get(http::header::RETRY_AFTER).unwrap(); + assert_eq!(retry_after_header.to_str().unwrap(), "3"); + + let json_body: JsonValue = from_json_slice(response.body()).unwrap(); + assert_eq!( + json_body, + json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + "retry_after_ms": 3000, + }) + ); + } + + #[test] + fn serialize_limit_exceeded_retry_after_datetime() { + let error = Error::new( + http::StatusCode::TOO_MANY_REQUESTS, + ErrorBody::Standard { + kind: ErrorKind::LimitExceeded { + retry_after: Some(RetryAfter::DateTime( + UNIX_EPOCH + Duration::from_secs(1_431_704_061), + )), + }, + message: "Too many requests".to_owned(), + }, + ); + + let response = error.try_into_http_response::>().unwrap(); + + assert_eq!(response.status(), http::StatusCode::TOO_MANY_REQUESTS); + let retry_after_header = response.headers().get(http::header::RETRY_AFTER).unwrap(); + assert_eq!(retry_after_header.to_str().unwrap(), "Fri, 15 May 2015 15:34:21 GMT"); + + let json_body: JsonValue = from_json_slice(response.body()).unwrap(); + assert_eq!( + json_body, + json!({ + "errcode": "M_LIMIT_EXCEEDED", + "error": "Too many requests", + }) + ); + } } diff --git a/crates/ruma-client-api/src/error/kind_serde.rs b/crates/ruma-client-api/src/error/kind_serde.rs index db27a5fd96..65aa49de1b 100644 --- a/crates/ruma-client-api/src/error/kind_serde.rs +++ b/crates/ruma-client-api/src/error/kind_serde.rs @@ -13,7 +13,7 @@ use serde::{ }; use serde_json::from_value as from_json_value; -use super::{ErrorKind, Extra}; +use super::{ErrorKind, Extra, RetryAfter}; use crate::PrivOwnedStr; enum Field<'de> { @@ -178,12 +178,13 @@ impl<'de> Visitor<'de> for ErrorKindVisitor { ErrCode::NotJson => ErrorKind::NotJson, ErrCode::NotFound => ErrorKind::NotFound, ErrCode::LimitExceeded => ErrorKind::LimitExceeded { - retry_after_ms: retry_after_ms + retry_after: retry_after_ms .map(from_json_value::) .transpose() .map_err(de::Error::custom)? .map(Into::into) - .map(Duration::from_millis), + .map(Duration::from_millis) + .map(RetryAfter::Delay), }, ErrCode::Unknown => ErrorKind::Unknown, ErrCode::Unrecognized => ErrorKind::Unrecognized, @@ -328,7 +329,7 @@ impl Serialize for ErrorKind { Self::UnknownToken { soft_logout: true } => { st.serialize_entry("soft_logout", &true)?; } - Self::LimitExceeded { retry_after_ms: Some(duration) } => { + Self::LimitExceeded { retry_after: Some(RetryAfter::Delay(duration)) } => { st.serialize_entry( "retry_after_ms", &UInt::try_from(duration.as_millis()).map_err(ser::Error::custom)?, diff --git a/crates/ruma-common/src/api/error.rs b/crates/ruma-common/src/api/error.rs index f6176f690c..0024410567 100644 --- a/crates/ruma-common/src/api/error.rs +++ b/crates/ruma-common/src/api/error.rs @@ -129,6 +129,13 @@ pub enum IntoHttpError { #[error("header serialization failed: {0}")] Header(#[from] http::header::InvalidHeaderValue), + /// Retry-After header serialization failed because the datetime provided is after the year + /// 9999. + #[error( + "Retry-After header serialization failed: the year of the datetime is bigger than 9999" + )] + RetryAfterInvalidDatetime, + /// HTTP request construction failed. #[error("HTTP request construction failed: {0}")] Http(#[from] http::Error),