From 3c2bc65330bac802ef22837eebbafe7ccea554f2 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 19:57:49 +0000 Subject: [PATCH 01/22] add workspace dependencies --- Cargo.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index cca776ce..7fda3474 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,3 +110,10 @@ webauthn-rs = { version = "0.5", features = [ "danger-allow-state-serialisation", "danger-credential-internals" ] } webauthn-rs-proto = "0.5" + +atrium-api = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-common = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-identity = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-oauth-client = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } + +hickory-resolver = { version = "0.24.1", features = ["tokio", "tokio-rustls", "rustls"] } From 8a595993bd407f0dd43115ba04b6826350aefd10 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 20:09:52 +0000 Subject: [PATCH 02/22] implement DNS text resolver --- src/models/Cargo.toml | 6 ++++++ src/models/src/entity/atproto.rs | 25 +++++++++++++++++++++++++ src/models/src/entity/mod.rs | 1 + 3 files changed, 32 insertions(+) create mode 100644 src/models/src/entity/atproto.rs diff --git a/src/models/Cargo.toml b/src/models/Cargo.toml index f80e2f9e..f9458e4a 100644 --- a/src/models/Cargo.toml +++ b/src/models/Cargo.toml @@ -81,6 +81,12 @@ validator = { workspace = true } webauthn-rs = { workspace = true } webauthn-rs-proto = { workspace = true } +atrium-api = { workspace = true } +atrium-identity = { workspace = true } +atrium-oauth-client = { workspace = true } + +hickory-resolver = { workspace = true } + [dev-dependencies] pretty_assertions = "1" rstest = "0.18.2" diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs new file mode 100644 index 00000000..389d716e --- /dev/null +++ b/src/models/src/entity/atproto.rs @@ -0,0 +1,25 @@ +use atrium_identity::handle::DnsTxtResolver as DnsTxtResolverTrait; +use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; + +struct DnsTxtResolver { + resolver: TokioAsyncResolver, +} + +impl Default for DnsTxtResolver { + fn default() -> Self { + Self { + resolver: TokioAsyncResolver::tokio_from_system_conf() + .expect("failed to create resolver"), + } + } +} + +impl DnsTxtResolverTrait for DnsTxtResolver { + async fn resolve( + &self, + query: &str, + ) -> Result, Box> { + let txt_lookup = self.resolver.txt_lookup(query).await?; + Ok(txt_lookup.iter().map(TXT::to_string).collect()) + } +} diff --git a/src/models/src/entity/mod.rs b/src/models/src/entity/mod.rs index 8e100e7c..33eea604 100644 --- a/src/models/src/entity/mod.rs +++ b/src/models/src/entity/mod.rs @@ -5,6 +5,7 @@ use sqlx::query; pub mod api_keys; pub mod app_version; +pub mod atproto; pub mod auth_codes; pub mod auth_providers; pub mod clients; From 6d5196b8e168bcaf4dc0ee6b635e6bd2b16817c1 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 20:11:26 +0000 Subject: [PATCH 03/22] declare ATProto client type --- src/models/src/entity/atproto.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 389d716e..11b24c5f 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -1,6 +1,19 @@ -use atrium_identity::handle::DnsTxtResolver as DnsTxtResolverTrait; +use atrium_identity::{ + did::CommonDidResolver, + handle::{AtprotoHandleResolver, DnsTxtResolver as DnsTxtResolverTrait}, +}; +use atrium_oauth_client::{DefaultHttpClient, OAuthClient}; use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; +use crate::database::DB; + +type AtprotoClient = OAuthClient< + DB, + DB, + CommonDidResolver, + AtprotoHandleResolver, +>; + struct DnsTxtResolver { resolver: TokioAsyncResolver, } From ef5c1e1aeff672746f8b4550eeb22c39df3f05a9 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 20:58:40 +0000 Subject: [PATCH 04/22] add ATProto client to app state --- src/models/Cargo.toml | 1 + src/models/src/app_state.rs | 56 +++++++++++++++++++++++++++ src/models/src/database.rs | 65 +++++++++++++++++++++++++++++++- src/models/src/entity/atproto.rs | 4 +- 4 files changed, 123 insertions(+), 3 deletions(-) diff --git a/src/models/Cargo.toml b/src/models/Cargo.toml index f9458e4a..45fc004e 100644 --- a/src/models/Cargo.toml +++ b/src/models/Cargo.toml @@ -82,6 +82,7 @@ webauthn-rs = { workspace = true } webauthn-rs-proto = { workspace = true } atrium-api = { workspace = true } +atrium-common = { workspace = true } atrium-identity = { workspace = true } atrium-oauth-client = { workspace = true } diff --git a/src/models/src/app_state.rs b/src/models/src/app_state.rs index 0e83c9ea..ed7a8fc9 100644 --- a/src/models/src/app_state.rs +++ b/src/models/src/app_state.rs @@ -1,8 +1,16 @@ +use crate::database::DB; use crate::email::EMail; +use crate::entity::atproto::{AtprotoClient, DnsTxtResolver}; use crate::events::event::Event; use crate::events::ip_blacklist_handler::IpBlacklistReq; use crate::events::listener::EventRouterMsg; use crate::ListenScheme; +use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; +use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; +use atrium_oauth_client::{ + AtprotoClientMetadata, AuthMethod, DefaultHttpClient, GrantType, KnownScope, OAuthClient, + OAuthClientConfig, OAuthResolverConfig, Scope, +}; use rauthy_common::constants::PROXY_MODE; use std::env; use std::sync::Arc; @@ -31,6 +39,7 @@ pub struct AppState { pub tx_events_router: flume::Sender, pub tx_ip_blacklist: flume::Sender, pub webauthn: Arc, + pub atproto: Arc, } impl AppState { @@ -154,6 +163,52 @@ impl AppState { .rp_name(&rp_name); let webauthn = Arc::new(builder.build().expect("Invalid configuration")); + let atproto = { + let http_client = Arc::new(DefaultHttpClient::default()); + + let listen_scheme = match listen_scheme { + ListenScheme::Http | ListenScheme::UnixHttp => "http", + ListenScheme::Https | ListenScheme::HttpHttps | ListenScheme::UnixHttps => "https", + }; + + let client_metadata = AtprotoClientMetadata { + client_id: format!( + "{listen_scheme}://{public_url}/auth/v1/atproto/client_metadata" + ), + client_uri: format!("{listen_scheme}://{public_url}"), + redirect_uris: vec![format!( + "{listen_scheme}://{public_url}/auth/v1/atproto/callback" + )], + token_endpoint_auth_method: AuthMethod::None, + grant_types: vec![GrantType::AuthorizationCode], + scopes: vec![Scope::Known(KnownScope::Atproto)], + jwks_uri: None, + token_endpoint_auth_signing_alg: None, + }; + + Arc::new( + OAuthClient::new(OAuthClientConfig { + client_metadata: client_metadata.clone(), + keys: None, + resolver: OAuthResolverConfig { + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), + http_client: http_client.clone(), + }), + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { + dns_txt_resolver: DnsTxtResolver::default(), + http_client: http_client.clone(), + }), + authorization_server_metadata: Default::default(), + protected_resource_metadata: Default::default(), + }, + state_store: DB, + session_store: DB, + }) + .expect("failed to initialize atproto client"), + ) + }; + Ok(Self { public_url, argon2_params, @@ -170,6 +225,7 @@ impl AppState { tx_events_router, tx_ip_blacklist, webauthn, + atproto, }) } diff --git a/src/models/src/database.rs b/src/models/src/database.rs index 01ad66c5..f628ddd6 100644 --- a/src/models/src/database.rs +++ b/src/models/src/database.rs @@ -3,8 +3,14 @@ use crate::entity::db_version::DbVersion; use crate::migration::db_migrate_dev::migrate_dev_data; use crate::migration::{anti_lockout, db_migrate, init_prod}; use actix_web::web; +use atrium_api::types::string::Did; +use atrium_common::store::Store; +use atrium_oauth_client::store::session::{Session, SessionStore}; +use atrium_oauth_client::store::state::{InternalStateData, StateStore}; use hiqlite::NodeConfig; -use rauthy_common::constants::{DATABASE_URL, DEV_MODE}; +use rauthy_common::constants::{ + CACHE_TTL_AUTH_PROVIDER_CALLBACK, CACHE_TTL_SESSION, DATABASE_URL, DEV_MODE, +}; use rauthy_common::{is_hiqlite, is_postgres}; use rauthy_error::{ErrorResponse, ErrorResponseType}; use serde::{Deserialize, Serialize}; @@ -232,3 +238,60 @@ impl DB { Ok(()) } } + +impl Store for DB { + type Error = hiqlite::Error; + + async fn get(&self, key: &String) -> Result, Self::Error> { + Self::client().get(Cache::AuthProviderCallback, key).await + } + + async fn set(&self, key: String, value: InternalStateData) -> Result<(), Self::Error> { + Self::client() + .put( + Cache::AuthProviderCallback, + key, + &value, + CACHE_TTL_AUTH_PROVIDER_CALLBACK, + ) + .await + } + + async fn del(&self, key: &String) -> Result<(), Self::Error> { + Self::client() + .delete(Cache::AuthProviderCallback, key.to_string()) + .await + } + + async fn clear(&self) -> Result<(), Self::Error> { + Self::client() + .clear_cache(Cache::AuthProviderCallback) + .await + } +} + +impl StateStore for DB {} + +impl Store for DB { + type Error = hiqlite::Error; + + async fn get(&self, key: &Did) -> Result, Self::Error> { + Self::client().get(Cache::Session, key.to_string()).await + } + + async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + Self::client() + .put(Cache::Session, key.to_string(), &value, CACHE_TTL_SESSION) + .await + } + + async fn del(&self, key: &Did) -> Result<(), Self::Error> { + Self::client().delete(Cache::Session, key.to_string()).await + } + + async fn clear(&self) -> Result<(), Self::Error> { + Self::client().clear_cache(Cache::Session).await + } +} + +impl SessionStore for DB {} diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 11b24c5f..dacef72e 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -7,14 +7,14 @@ use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; use crate::database::DB; -type AtprotoClient = OAuthClient< +pub type AtprotoClient = OAuthClient< DB, DB, CommonDidResolver, AtprotoHandleResolver, >; -struct DnsTxtResolver { +pub struct DnsTxtResolver { resolver: TokioAsyncResolver, } From ec1cf45d5593d3e7aae05eb776f9e2831598d85f Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 20:59:26 +0000 Subject: [PATCH 05/22] introduce ATProto callback trait --- src/models/src/entity/atproto.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index dacef72e..083e1b8e 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -1,11 +1,14 @@ +use actix_web::{cookie::Cookie, http::header::HeaderValue, web}; use atrium_identity::{ did::CommonDidResolver, handle::{AtprotoHandleResolver, DnsTxtResolver as DnsTxtResolverTrait}, }; use atrium_oauth_client::{DefaultHttpClient, OAuthClient}; use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; +use rauthy_api_types::atproto; +use rauthy_error::ErrorResponse; -use crate::database::DB; +use crate::{app_state::AppState, database::DB}; pub type AtprotoClient = OAuthClient< DB, @@ -36,3 +39,10 @@ impl DnsTxtResolverTrait for DnsTxtResolver { Ok(txt_lookup.iter().map(TXT::to_string).collect()) } } + +pub trait AtprotoCallback { + async fn login_start( + data: &web::Data, + payload: atproto::LoginRequest, + ) -> Result<(Cookie<'_>, String, HeaderValue), ErrorResponse>; +} From 0bf67880dcdf36379d9544940888509d7805f0d3 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 21:02:11 +0000 Subject: [PATCH 06/22] declare ATProto login request type --- src/api_types/src/atproto.rs | 21 +++++++++++++++++++++ src/api_types/src/lib.rs | 1 + 2 files changed, 22 insertions(+) create mode 100644 src/api_types/src/atproto.rs diff --git a/src/api_types/src/atproto.rs b/src/api_types/src/atproto.rs new file mode 100644 index 00000000..a2ab408a --- /dev/null +++ b/src/api_types/src/atproto.rs @@ -0,0 +1,21 @@ +use rauthy_common::constants::RE_URI; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use validator::Validate; + +#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)] +pub struct LoginRequest { + /// Validation: + /// `^(did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]|([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)$` + pub at_id: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub redirect_uri: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub state: Option, + + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub pkce_challenge: String, +} diff --git a/src/api_types/src/lib.rs b/src/api_types/src/lib.rs index 2bd6b09e..58b8ef45 100644 --- a/src/api_types/src/lib.rs +++ b/src/api_types/src/lib.rs @@ -1,4 +1,5 @@ pub mod api_keys; +pub mod atproto; pub mod auth_providers; pub mod blacklist; pub mod clients; From 9a19d401a06783019138b00491ebc02d0b3b1af9 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 21:06:54 +0000 Subject: [PATCH 07/22] implement ATProto callback trait --- src/models/src/entity/atproto.rs | 65 +++++++++++++++++++++++-- src/models/src/entity/auth_providers.rs | 2 +- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 083e1b8e..3b87036f 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -3,12 +3,17 @@ use atrium_identity::{ did::CommonDidResolver, handle::{AtprotoHandleResolver, DnsTxtResolver as DnsTxtResolverTrait}, }; -use atrium_oauth_client::{DefaultHttpClient, OAuthClient}; +use atrium_oauth_client::{AuthorizeOptions, DefaultHttpClient, OAuthClient}; +use cryptr::utils::secure_random_alnum; use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; use rauthy_api_types::atproto; -use rauthy_error::ErrorResponse; +use rauthy_common::constants::{COOKIE_UPSTREAM_CALLBACK, UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS}; +use rauthy_error::{ErrorResponse, ErrorResponseType}; +use tracing::error; -use crate::{app_state::AppState, database::DB}; +use crate::{api_cookie::ApiCookie, app_state::AppState, database::DB}; + +use super::auth_providers::{AuthProviderCallback, AuthProviderType}; pub type AtprotoClient = OAuthClient< DB, @@ -46,3 +51,57 @@ pub trait AtprotoCallback { payload: atproto::LoginRequest, ) -> Result<(Cookie<'_>, String, HeaderValue), ErrorResponse>; } + +impl AtprotoCallback for AuthProviderCallback { + async fn login_start( + data: &web::Data, + payload: atproto::LoginRequest, + ) -> Result<(Cookie<'_>, String, HeaderValue), ErrorResponse> { + let slf = Self { + callback_id: secure_random_alnum(32), + xsrf_token: secure_random_alnum(32), + typ: AuthProviderType::Custom, + + req_client_id: String::from("atproto"), + req_scopes: None, + req_redirect_uri: payload.redirect_uri.clone(), + req_state: payload.state.clone(), + req_nonce: None, + req_code_challenge: None, + req_code_challenge_method: None, + + provider_id: String::from("atproto"), + + pkce_challenge: payload.pkce_challenge, + }; + + slf.save().await?; + + let options = AuthorizeOptions { + state: Some(slf.callback_id.clone()), + ..Default::default() + }; + + match data.atproto.authorize(&payload.at_id, options).await { + Ok(location) => { + let cookie = ApiCookie::build( + COOKIE_UPSTREAM_CALLBACK, + &slf.callback_id, + UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS as i64, + ); + let header = + HeaderValue::from_str(&location).expect("Location HeaderValue to be correct"); + + Ok((cookie, slf.xsrf_token, header)) + } + Err(error) => { + error!(%error, "failed to build pushed authorization request for atproto"); + + Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "failed to build pushed authorization request for atproto", + )) + } + } + } +} diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index 636296c0..9774310f 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -743,7 +743,7 @@ impl AuthProviderCallback { } } - async fn save(&self) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { DB::client() .put( Cache::AuthProviderCallback, From e75cb5c0c083cf4c0ceb44898b30b961505d107c Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 21:16:21 +0000 Subject: [PATCH 08/22] declare ATProto callback request type --- src/api_types/src/atproto.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/api_types/src/atproto.rs b/src/api_types/src/atproto.rs index a2ab408a..5faf311a 100644 --- a/src/api_types/src/atproto.rs +++ b/src/api_types/src/atproto.rs @@ -1,4 +1,4 @@ -use rauthy_common::constants::RE_URI; +use rauthy_common::constants::{RE_ALNUM, RE_URI}; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use validator::Validate; @@ -19,3 +19,22 @@ pub struct LoginRequest { #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] pub pkce_challenge: String, } + +#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)] +pub struct CallbackRequest { + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub state: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub code: String, + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub iss: Option, + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub xsrf_token: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub pkce_verifier: String, +} From a58b4ba2748c408193e3ace465dc79beb61c86bf Mon Sep 17 00:00:00 2001 From: avdb13 Date: Mon, 9 Dec 2024 23:13:29 +0000 Subject: [PATCH 09/22] extend ATProto callback trait --- src/models/src/entity/atproto.rs | 283 +++++++++++++++++++++++- src/models/src/entity/auth_providers.rs | 2 +- 2 files changed, 272 insertions(+), 13 deletions(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 3b87036f..52081ead 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -1,19 +1,38 @@ -use actix_web::{cookie::Cookie, http::header::HeaderValue, web}; +use actix_web::{ + cookie::Cookie, + http::header::{self, HeaderValue}, + web, HttpRequest, +}; +use atrium_api::{agent::Agent, com::atproto::server::get_session, types::Object}; use atrium_identity::{ did::CommonDidResolver, handle::{AtprotoHandleResolver, DnsTxtResolver as DnsTxtResolverTrait}, }; -use atrium_oauth_client::{AuthorizeOptions, DefaultHttpClient, OAuthClient}; +use atrium_oauth_client::{AuthorizeOptions, CallbackParams, DefaultHttpClient, OAuthClient}; use cryptr::utils::secure_random_alnum; use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; use rauthy_api_types::atproto; -use rauthy_common::constants::{COOKIE_UPSTREAM_CALLBACK, UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS}; +use rauthy_common::constants::{ + COOKIE_UPSTREAM_CALLBACK, PROVIDER_LINK_COOKIE, UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS, +}; use rauthy_error::{ErrorResponse, ErrorResponseType}; -use tracing::error; +use time::OffsetDateTime; +use tracing::{debug, error}; -use crate::{api_cookie::ApiCookie, app_state::AppState, database::DB}; +use crate::{ + api_cookie::ApiCookie, + app_state::AppState, + database::DB, + entity::{ + auth_codes::AuthCode, auth_providers::AuthProviderLinkCookie, clients::Client, users::User, + }, + AuthStep, AuthStepLoggedIn, +}; -use super::auth_providers::{AuthProviderCallback, AuthProviderType}; +use super::{ + auth_providers::{AuthProviderCallback, AuthProviderType}, + sessions::Session, +}; pub type AtprotoClient = OAuthClient< DB, @@ -46,17 +65,24 @@ impl DnsTxtResolverTrait for DnsTxtResolver { } pub trait AtprotoCallback { - async fn login_start( - data: &web::Data, + async fn login_start<'a>( + data: &'a web::Data, payload: atproto::LoginRequest, - ) -> Result<(Cookie<'_>, String, HeaderValue), ErrorResponse>; + ) -> Result<(Cookie<'a>, String, HeaderValue), ErrorResponse>; + + async fn login_finish<'a>( + data: &'a web::Data, + req: &'a HttpRequest, + payload: &'a atproto::CallbackRequest, + session: Session, + ) -> Result<(AuthStep, Cookie<'a>), ErrorResponse>; } impl AtprotoCallback for AuthProviderCallback { - async fn login_start( - data: &web::Data, + async fn login_start<'a>( + data: &'a web::Data, payload: atproto::LoginRequest, - ) -> Result<(Cookie<'_>, String, HeaderValue), ErrorResponse> { + ) -> Result<(Cookie<'a>, String, HeaderValue), ErrorResponse> { let slf = Self { callback_id: secure_random_alnum(32), xsrf_token: secure_random_alnum(32), @@ -104,4 +130,237 @@ impl AtprotoCallback for AuthProviderCallback { } } } + + async fn login_finish<'a>( + data: &'a web::Data, + req: &'a HttpRequest, + payload: &'a atproto::CallbackRequest, + session: Session, + ) -> Result<(AuthStep, Cookie<'a>), ErrorResponse> { + // the callback id for the cache should be inside the encrypted cookie + let callback_id = ApiCookie::from_req(req, COOKIE_UPSTREAM_CALLBACK).ok_or_else(|| { + ErrorResponse::new( + ErrorResponseType::Forbidden, + "Missing encrypted callback cookie", + ) + })?; + + // validate state + if callback_id != payload.state { + AuthProviderCallback::delete(callback_id).await?; + + error!("`state` does not match"); + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "`state` does not match", + )); + } + debug!("callback state is valid"); + + // validate csrf token + let slf = AuthProviderCallback::find(callback_id).await?; + if slf.xsrf_token != payload.xsrf_token { + AuthProviderCallback::delete(slf.callback_id).await?; + + error!("invalid CSRF token"); + return Err(ErrorResponse::new( + ErrorResponseType::Unauthorized, + "invalid CSRF token", + )); + } + debug!("callback csrf token is valid"); + + // request is valid -> fetch token for the user + let params = CallbackParams { + code: payload.code.clone(), + state: Some(payload.state.clone()), + iss: payload.iss.clone(), + }; + // return early if we got any error + let (session_manager, _) = data.atproto.callback(params).await.map_err(|error| { + error!(%error, "failed to complete authorization callback for atproto"); + + ErrorResponse::new( + ErrorResponseType::BadRequest, + "failed to complete authorization callback for atproto", + ) + })?; + + let agent = Agent::new(session_manager); + + let (did, email, email_confirmed) = match agent.api.com.atproto.server.get_session().await { + Ok(Object { + data: + get_session::OutputData { + did, + email, + email_confirmed, + .. + }, + .. + }) => (did, email, email_confirmed), + Err(error) => { + error!(%error, "failed to get session for atproto"); + + return Err(ErrorResponse::new( + ErrorResponseType::Internal, + "failed to get session for atproto", + )); + } + }; + + let link_cookie = ApiCookie::from_req(req, PROVIDER_LINK_COOKIE) + .and_then(|value| AuthProviderLinkCookie::try_from(value.as_str()).ok()); + + if email.is_none() { + todo!() + } + + let claims_user_id = did.to_string(); + + let user_opt = match User::find_by_federation("atproto", &claims_user_id).await { + Ok(user) => { + debug!( + "found already existing user by federation lookup: {:?}", + user + ); + Some(user) + } + Err(_) => { + debug!("did not find already existing user by federation lookup - making sure email does not exist"); + + if let Ok(mut user) = User::find_by_email(email.as_ref().unwrap().to_string()).await + { + if let Some(link) = link_cookie.as_ref() { + if link.provider_id != "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "bad provider_id in link cookie".to_string(), + )); + } + + if link.user_id != user.id { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "bad user_id in link cookie".to_string(), + )); + } + + if link.user_email != user.email { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Invalid E-Mail".to_string(), + )); + } + + user.auth_provider_id = Some("atproto".to_owned()); + user.federation_uid = Some(claims_user_id.clone()); + + Some(user) + } else { + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, format!( + "User with email '{}' already exists but is not linked to this provider.", + user.email + ))); + } + } else { + None + } + } + }; + debug!("user_opt:\n{:?}", user_opt); + + let now = OffsetDateTime::now_utc().unix_timestamp(); + let user = if let Some(mut user) = user_opt { + let mut old_email = None; + let mut forbidden_error = None; + + if user.federation_uid.is_none() + || user.federation_uid.as_deref() != Some(&claims_user_id) + { + forbidden_error = Some("non-federated user or ID mismatch"); + } + + if user.auth_provider_id.as_deref() != Some("atproto") { + forbidden_error = Some("invalid login from wrong auth provider"); + } + + if let Some(err) = forbidden_error { + user.last_failed_login = Some(now); + user.failed_login_attempts = + Some(user.failed_login_attempts.unwrap_or_default() + 1); + user.save(old_email).await?; + + return Err(ErrorResponse::new( + ErrorResponseType::Forbidden, + err.to_string(), + )); + } + + if Some(user.email.as_str()) != email.as_deref() { + old_email = Some(user.email); + user.email = email.as_ref().unwrap().to_string(); + } + + user.last_login = Some(now); + user.last_failed_login = None; + user.failed_login_attempts = None; + + user.save(old_email).await?; + user + } else { + let new_user = User { + email: email.as_ref().unwrap().to_string(), + given_name: "N/A".to_string(), + family_name: None, + roles: Default::default(), + enabled: true, + email_verified: email_confirmed.unwrap_or(false), + last_login: Some(now), + language: Default::default(), + auth_provider_id: Some("atproto".to_owned()), + federation_uid: Some(claims_user_id.to_string()), + ..Default::default() + }; + User::create_federated(new_user).await? + }; + + user.check_enabled()?; + user.check_expired()?; + + if link_cookie.is_some() { + return Ok(( + AuthStep::ProviderLink, + AuthProviderLinkCookie::deletion_cookie(), + )); + } + + let client = Client::default(); + + let code = AuthCode::new( + user.id.clone(), + "atproto".to_owned(), + Some(session.id.clone()), + slf.req_code_challenge, + slf.req_code_challenge_method, + slf.req_nonce, + vec!["atproto".to_owned()], + client.auth_code_lifetime, + ); + code.save().await?; + + let auth_step = AuthStep::LoggedIn(AuthStepLoggedIn { + user_id: user.id, + email: user.email, + header_loc: ( + header::LOCATION, + HeaderValue::from_str(&(format!("{}?code={}", slf.req_redirect_uri, code.id)))?, + ), + header_csrf: Session::get_csrf_header(&session.csrf_token), + header_origin: None, + }); + + let cookie = ApiCookie::build(COOKIE_UPSTREAM_CALLBACK, "", 0); + Ok((auth_step, cookie)) + } } diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index 9774310f..f50f3e98 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -729,7 +729,7 @@ impl AuthProviderCallback { Ok(()) } - async fn find(callback_id: String) -> Result { + pub async fn find(callback_id: String) -> Result { let opt: Option = DB::client() .get(Cache::AuthProviderCallback, callback_id) .await?; From c865708b700b5947467e02c150749da3c0430e78 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 09:40:00 +0000 Subject: [PATCH 10/22] add ATProto client metadata handler --- src/api/src/atproto.rs | 25 +++++++++++++++++++++++++ src/api/src/lib.rs | 1 + 2 files changed, 26 insertions(+) create mode 100644 src/api/src/atproto.rs diff --git a/src/api/src/atproto.rs b/src/api/src/atproto.rs new file mode 100644 index 00000000..b39c2308 --- /dev/null +++ b/src/api/src/atproto.rs @@ -0,0 +1,25 @@ +use actix_web::{ + get, + http::header::{self, HeaderValue}, + web, HttpResponse, +}; +use rauthy_error::ErrorResponse; +use rauthy_models::app_state::AppState; + +#[utoipa::path( + get, + path = "/atproto/client_metadata", + tag = "atproto", + responses( + (status = 200, description = "OK"), + ), +)] +#[get("/atproto/client_metadata")] +pub async fn get_client_metadata(data: web::Data) -> Result { + Ok(HttpResponse::Ok() + .insert_header(( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str("*").unwrap(), + )) + .json(&data.atproto.client_metadata)) +} diff --git a/src/api/src/lib.rs b/src/api/src/lib.rs index 07aff120..5a1ceccf 100644 --- a/src/api/src/lib.rs +++ b/src/api/src/lib.rs @@ -17,6 +17,7 @@ use rust_embed::RustEmbed; use tracing::error; pub mod api_keys; +pub mod atproto; pub mod auth_providers; pub mod blacklist; pub mod clients; From 103c86ebce87bf8fe34434c17bff027691ea6f56 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 09:43:13 +0000 Subject: [PATCH 11/22] add ATProto login handler --- src/api/src/atproto.rs | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/api/src/atproto.rs b/src/api/src/atproto.rs index b39c2308..9c568ed5 100644 --- a/src/api/src/atproto.rs +++ b/src/api/src/atproto.rs @@ -1,10 +1,18 @@ use actix_web::{ get, - http::header::{self, HeaderValue}, - web, HttpResponse, + http::header::{self, HeaderValue, LOCATION}, + post, + web::{self, Json}, + HttpResponse, }; +use rauthy_api_types::atproto; use rauthy_error::ErrorResponse; -use rauthy_models::app_state::AppState; +use rauthy_models::{ + app_state::AppState, + entity::{atproto::AtprotoCallback, auth_providers::AuthProviderCallback}, +}; + +use crate::ReqPrincipal; #[utoipa::path( get, @@ -23,3 +31,31 @@ pub async fn get_client_metadata(data: web::Data) -> Result, + payload: Json, + principal: ReqPrincipal, +) -> Result { + principal.validate_session_auth_or_init()?; + + let payload = payload.into_inner(); + let (cookie, xsrf_token, location) = + ::login_start(&data, payload).await?; + + Ok(HttpResponse::Accepted() + .insert_header((LOCATION, location)) + .cookie(cookie) + .body(xsrf_token)) +} From 6a1513e65d826cf0b2b1281e67f81d40a4cad7a6 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 09:47:39 +0000 Subject: [PATCH 12/22] add ATProto callback handler --- src/api/src/atproto.rs | 54 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/src/api/src/atproto.rs b/src/api/src/atproto.rs index 9c568ed5..900e32db 100644 --- a/src/api/src/atproto.rs +++ b/src/api/src/atproto.rs @@ -2,17 +2,17 @@ use actix_web::{ get, http::header::{self, HeaderValue, LOCATION}, post, - web::{self, Json}, - HttpResponse, + web::{self, Json, Query}, + HttpRequest, HttpResponse, }; -use rauthy_api_types::atproto; -use rauthy_error::ErrorResponse; +use rauthy_api_types::{atproto, auth_providers::ProviderLookupResponse}; +use rauthy_error::{ErrorResponse, ErrorResponseType}; use rauthy_models::{ app_state::AppState, entity::{atproto::AtprotoCallback, auth_providers::AuthProviderCallback}, }; -use crate::ReqPrincipal; +use crate::{map_auth_step, ReqPrincipal}; #[utoipa::path( get, @@ -59,3 +59,47 @@ pub async fn post_login( .cookie(cookie) .body(xsrf_token)) } + +#[utoipa::path( + post, + path = "/atproto/callback", + tag = "atproto", + responses( + (status = 200, description = "OK", body = ProviderLookupResponse), + (status = 400, description = "BadRequest", body = ErrorResponse), + (status = 404, description = "NotFound", body = ErrorResponse), + ), +)] +#[post("/atproto/callback")] +#[tracing::instrument( + name = "post_provider_callback", + skip_all, fields(callback_id = payload.state) +)] +pub async fn post_callback( + data: web::Data, + req: HttpRequest, + payload: Query, + principal: ReqPrincipal, +) -> Result { + principal.validate_session_auth_or_init()?; + + let payload = payload.into_inner(); + let session = principal.get_session()?; + let (auth_step, cookie) = ::login_finish( + &data, + &req, + &payload, + session.clone(), + ) + .await?; + + let mut resp = map_auth_step(auth_step, &req).await?; + resp.add_cookie(&cookie).map_err(|err| { + ErrorResponse::new( + ErrorResponseType::Internal, + format!("Error adding cookie after map_auth_step: {}", err), + ) + })?; + + Ok(resp) +} From 12d0577278d9ebc1a1ab596331c5824d2016c8d7 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 09:49:11 +0000 Subject: [PATCH 13/22] add handlers to router --- src/bin/src/main.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/bin/src/main.rs b/src/bin/src/main.rs index 6828eb64..3e5ac03a 100644 --- a/src/bin/src/main.rs +++ b/src/bin/src/main.rs @@ -16,8 +16,8 @@ use rauthy_common::utils::UseDummyAddress; use rauthy_common::{is_hiqlite, is_sqlite, password_hasher}; use rauthy_handlers::openapi::ApiDoc; use rauthy_handlers::{ - api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles, - scopes, sessions, users, + api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, + roles, scopes, sessions, users, }; use rauthy_middlewares::csrf_protection::CsrfProtectionMiddleware; use rauthy_middlewares::ip_blacklist::RauthyIpBlacklistMiddleware; @@ -556,7 +556,10 @@ async fn actix_main(app_state: web::Data) -> std::io::Result<()> { .service(oidc::get_well_known) .service(generic::get_health) .service(generic::get_ready) - .service(generic::get_static_assets), + .service(generic::get_static_assets) + .service(atproto::get_client_metadata) + .service(atproto::post_login) + .service(atproto::post_callback), ), ); From b453c846eb67806feab2d9eb5062f5c185c4db0d Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 09:50:35 +0000 Subject: [PATCH 14/22] fix clippy warning --- src/models/src/entity/atproto.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 52081ead..b485f8b5 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -1,3 +1,5 @@ +use std::future::Future; + use actix_web::{ cookie::Cookie, http::header::{self, HeaderValue}, @@ -65,17 +67,17 @@ impl DnsTxtResolverTrait for DnsTxtResolver { } pub trait AtprotoCallback { - async fn login_start<'a>( + fn login_start<'a>( data: &'a web::Data, payload: atproto::LoginRequest, - ) -> Result<(Cookie<'a>, String, HeaderValue), ErrorResponse>; + ) -> impl Future, String, HeaderValue), ErrorResponse>>; - async fn login_finish<'a>( + fn login_finish<'a>( data: &'a web::Data, req: &'a HttpRequest, payload: &'a atproto::CallbackRequest, session: Session, - ) -> Result<(AuthStep, Cookie<'a>), ErrorResponse>; + ) -> impl Future), ErrorResponse>>; } impl AtprotoCallback for AuthProviderCallback { From f45f5e815db72c9199b4ecf200c3d674d1a6a725 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 10 Dec 2024 15:42:48 +0000 Subject: [PATCH 15/22] extend OpenAPI schema with ATProto models and endpoints --- src/api/src/openapi.rs | 15 +++++++++++++-- src/models/src/entity/atproto.rs | 11 +++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/api/src/openapi.rs b/src/api/src/openapi.rs index 1dc921e6..ecb48271 100644 --- a/src/api/src/openapi.rs +++ b/src/api/src/openapi.rs @@ -1,8 +1,11 @@ use crate::{ - api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles, - scopes, sessions, users, + api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, + roles, scopes, sessions, users, }; use actix_web::web; +use rauthy_api_types::atproto::{ + CallbackRequest as AtprotoCallbackRequest, LoginRequest as AtprotoLoginRequest, +}; use rauthy_api_types::{ api_keys::*, auth_providers::*, blacklist::*, clients::*, events::*, fed_cm::*, generic::*, groups::*, oidc::*, roles::*, scopes::*, sessions::*, users::*, @@ -26,6 +29,10 @@ use utoipa::{openapi, OpenApi}; api_keys::get_api_key_test, api_keys::put_api_key_secret, + atproto::get_client_metadata, + atproto::post_login, + atproto::post_callback, + auth_providers::post_providers, auth_providers::post_provider, auth_providers::post_provider_lookup, @@ -152,6 +159,7 @@ use utoipa::{openapi, OpenApi}; components( schemas( entity::colors::Colors, + entity::atproto::DnsTxtResolver, entity::fed_cm::FedCMAccount, entity::fed_cm::FedCMAccounts, entity::fed_cm::FedCMIdPBranding, @@ -186,6 +194,8 @@ use utoipa::{openapi, OpenApi}; ErrorResponseType, ApiKeyRequest, + AtprotoCallbackRequest, + AtprotoLoginRequest, AuthCodeRequest, AuthRequest, IpBlacklistRequest, @@ -294,6 +304,7 @@ use utoipa::{openapi, OpenApi}; (name = "generic", description = "Generic endpoints"), (name = "webid", description = "WebID endpoints"), (name = "fed_cm", description = "Experimental FedCM endpoints"), + (name = "atproto", description = "Experimental ATProto endpoints"), (name = "deprecated", description = "Deprecated endpoints - will be removed in a future version"), ), )] diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index b485f8b5..55439086 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -20,6 +20,7 @@ use rauthy_common::constants::{ use rauthy_error::{ErrorResponse, ErrorResponseType}; use time::OffsetDateTime; use tracing::{debug, error}; +use utoipa::{PartialSchema, ToSchema}; use crate::{ api_cookie::ApiCookie, @@ -56,6 +57,16 @@ impl Default for DnsTxtResolver { } } +impl ToSchema for DnsTxtResolver {} + +impl PartialSchema for DnsTxtResolver { + fn schema() -> utoipa::openapi::RefOr { + utoipa::openapi::RefOr::T(utoipa::openapi::Schema::Object( + utoipa::openapi::ObjectBuilder::new().build(), + )) + } +} + impl DnsTxtResolverTrait for DnsTxtResolver { async fn resolve( &self, From 1d2d235cd2d5b986615bf2955f445fda5c879895 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 15:30:46 +0000 Subject: [PATCH 16/22] convert `AtprotoClient` type to newtype struct --- src/models/src/app_state.rs | 58 ++----------------- src/models/src/entity/atproto.rs | 99 ++++++++++++++++++++++++++++---- 2 files changed, 92 insertions(+), 65 deletions(-) diff --git a/src/models/src/app_state.rs b/src/models/src/app_state.rs index ed7a8fc9..8df1354e 100644 --- a/src/models/src/app_state.rs +++ b/src/models/src/app_state.rs @@ -1,16 +1,9 @@ -use crate::database::DB; use crate::email::EMail; -use crate::entity::atproto::{AtprotoClient, DnsTxtResolver}; +use crate::entity::atproto::AtprotoClient; use crate::events::event::Event; use crate::events::ip_blacklist_handler::IpBlacklistReq; use crate::events::listener::EventRouterMsg; use crate::ListenScheme; -use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}; -use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; -use atrium_oauth_client::{ - AtprotoClientMetadata, AuthMethod, DefaultHttpClient, GrantType, KnownScope, OAuthClient, - OAuthClientConfig, OAuthResolverConfig, Scope, -}; use rauthy_common::constants::PROXY_MODE; use std::env; use std::sync::Arc; @@ -39,7 +32,7 @@ pub struct AppState { pub tx_events_router: flume::Sender, pub tx_ip_blacklist: flume::Sender, pub webauthn: Arc, - pub atproto: Arc, + pub atproto: AtprotoClient, } impl AppState { @@ -163,51 +156,8 @@ impl AppState { .rp_name(&rp_name); let webauthn = Arc::new(builder.build().expect("Invalid configuration")); - let atproto = { - let http_client = Arc::new(DefaultHttpClient::default()); - - let listen_scheme = match listen_scheme { - ListenScheme::Http | ListenScheme::UnixHttp => "http", - ListenScheme::Https | ListenScheme::HttpHttps | ListenScheme::UnixHttps => "https", - }; - - let client_metadata = AtprotoClientMetadata { - client_id: format!( - "{listen_scheme}://{public_url}/auth/v1/atproto/client_metadata" - ), - client_uri: format!("{listen_scheme}://{public_url}"), - redirect_uris: vec![format!( - "{listen_scheme}://{public_url}/auth/v1/atproto/callback" - )], - token_endpoint_auth_method: AuthMethod::None, - grant_types: vec![GrantType::AuthorizationCode], - scopes: vec![Scope::Known(KnownScope::Atproto)], - jwks_uri: None, - token_endpoint_auth_signing_alg: None, - }; - - Arc::new( - OAuthClient::new(OAuthClientConfig { - client_metadata: client_metadata.clone(), - keys: None, - resolver: OAuthResolverConfig { - did_resolver: CommonDidResolver::new(CommonDidResolverConfig { - plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), - http_client: http_client.clone(), - }), - handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { - dns_txt_resolver: DnsTxtResolver::default(), - http_client: http_client.clone(), - }), - authorization_server_metadata: Default::default(), - protected_resource_metadata: Default::default(), - }, - state_store: DB, - session_store: DB, - }) - .expect("failed to initialize atproto client"), - ) - }; + let atproto = AtprotoClient::new(&listen_scheme, &public_url) + .expect("failed to initialize atproto client"); Ok(Self { public_url, diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 55439086..19459c92 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -1,4 +1,4 @@ -use std::future::Future; +use std::{future::Future, ops::Deref, sync::Arc}; use actix_web::{ cookie::Cookie, @@ -7,10 +7,15 @@ use actix_web::{ }; use atrium_api::{agent::Agent, com::atproto::server::get_session, types::Object}; use atrium_identity::{ - did::CommonDidResolver, - handle::{AtprotoHandleResolver, DnsTxtResolver as DnsTxtResolverTrait}, + did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, + handle::{ + AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver as DnsTxtResolverTrait, + }, +}; +use atrium_oauth_client::{ + AtprotoClientMetadata, AuthMethod, AuthorizeOptions, CallbackParams, DefaultHttpClient, + GrantType, KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, }; -use atrium_oauth_client::{AuthorizeOptions, CallbackParams, DefaultHttpClient, OAuthClient}; use cryptr::utils::secure_random_alnum; use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; use rauthy_api_types::atproto; @@ -29,7 +34,7 @@ use crate::{ entity::{ auth_codes::AuthCode, auth_providers::AuthProviderLinkCookie, clients::Client, users::User, }, - AuthStep, AuthStepLoggedIn, + AuthStep, AuthStepLoggedIn, ListenScheme, }; use super::{ @@ -37,12 +42,84 @@ use super::{ sessions::Session, }; -pub type AtprotoClient = OAuthClient< - DB, - DB, - CommonDidResolver, - AtprotoHandleResolver, ->; +#[derive(Clone)] +pub struct AtprotoClient( + Arc< + OAuthClient< + DB, + DB, + CommonDidResolver, + AtprotoHandleResolver, + >, + >, +); + +impl std::fmt::Debug for AtprotoClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AtprotoClient").finish() + } +} + +impl AtprotoClient { + pub fn new( + listen_scheme: &ListenScheme, + public_url: &str, + ) -> Result { + let http_client = Arc::new(DefaultHttpClient::default()); + + let listen_scheme = match listen_scheme { + ListenScheme::Http | ListenScheme::UnixHttp => "http", + ListenScheme::Https | ListenScheme::HttpHttps | ListenScheme::UnixHttps => "https", + }; + + let client_metadata = AtprotoClientMetadata { + client_id: format!("{listen_scheme}://{public_url}/auth/v1/atproto/client_metadata"), + client_uri: String::new(), + redirect_uris: vec![format!( + "{listen_scheme}://{public_url}/auth/v1/atproto/callback" + )], + token_endpoint_auth_method: AuthMethod::None, + grant_types: vec![GrantType::AuthorizationCode], + scopes: vec![Scope::Known(KnownScope::Atproto)], + jwks_uri: None, + token_endpoint_auth_signing_alg: None, + }; + + let config = OAuthClientConfig { + client_metadata: client_metadata.clone(), + keys: None, + resolver: OAuthResolverConfig { + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), + http_client: http_client.clone(), + }), + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { + dns_txt_resolver: DnsTxtResolver::default(), + http_client: http_client.clone(), + }), + authorization_server_metadata: Default::default(), + protected_resource_metadata: Default::default(), + }, + state_store: DB, + session_store: DB, + }; + + OAuthClient::new(config).map(Arc::new).map(AtprotoClient) + } +} + +impl Deref for AtprotoClient { + type Target = OAuthClient< + DB, + DB, + CommonDidResolver, + AtprotoHandleResolver, + >; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} pub struct DnsTxtResolver { resolver: TokioAsyncResolver, From 83b48ca74f5501d6c2a37c2990ee8014a66d715b Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 15:48:51 +0000 Subject: [PATCH 17/22] blacklist upstream provider name --- .../admin/providers/ProviderConfig.svelte | 2 +- .../admin/providers/ProviderTileAddNew.svelte | 2 +- src/api/src/auth_providers.rs | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/frontend/src/components/admin/providers/ProviderConfig.svelte b/frontend/src/components/admin/providers/ProviderConfig.svelte index 571faa64..fdc4f17a 100644 --- a/frontend/src/components/admin/providers/ProviderConfig.svelte +++ b/frontend/src/components/admin/providers/ProviderConfig.svelte @@ -58,7 +58,7 @@ token_endpoint: yup.string().url(), userinfo_endpoint: yup.string().url(), - name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128"), + name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved"), client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128"), client_secret: yup.string().trim().max(256, "Max 256 characters"), scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128"), diff --git a/frontend/src/components/admin/providers/ProviderTileAddNew.svelte b/frontend/src/components/admin/providers/ProviderTileAddNew.svelte index 07877039..fe68f5e6 100644 --- a/frontend/src/components/admin/providers/ProviderTileAddNew.svelte +++ b/frontend/src/components/admin/providers/ProviderTileAddNew.svelte @@ -79,7 +79,7 @@ token_endpoint: yup.string().url().required('Required'), userinfo_endpoint: yup.string().url().required('Required'), - name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").required('Required'), + name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved").required('Required'), client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128").required('Required'), client_secret: yup.string().trim().max(256, "Max 256 characters"), scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128").required('Required'), diff --git a/src/api/src/auth_providers.rs b/src/api/src/auth_providers.rs index aa1f257e..20d15ee3 100644 --- a/src/api/src/auth_providers.rs +++ b/src/api/src/auth_providers.rs @@ -68,6 +68,13 @@ pub async fn post_provider( ) -> Result { principal.validate_admin_session()?; + if payload.name == "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Must not contain a reserved name".to_string(), + )); + } + if !payload.use_pkce && payload.client_secret.is_none() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -261,6 +268,13 @@ pub async fn put_provider( ) -> Result { principal.validate_admin_session()?; + if payload.name == "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Must not contain a reserved name".to_string(), + )); + } + if !payload.use_pkce && payload.client_secret.is_none() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, From fe4011d49b0de389216d84d5f4828149d08381bb Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 15:53:12 +0000 Subject: [PATCH 18/22] do not allocate error message strings --- src/models/src/entity/atproto.rs | 11 ++++----- src/models/src/entity/auth_providers.rs | 32 ++++++++----------------- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 19459c92..98d324df 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -325,21 +325,21 @@ impl AtprotoCallback for AuthProviderCallback { if link.provider_id != "atproto" { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad provider_id in link cookie".to_string(), + "bad provider_id in link cookie", )); } if link.user_id != user.id { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad user_id in link cookie".to_string(), + "bad user_id in link cookie", )); } if link.user_email != user.email { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "Invalid E-Mail".to_string(), + "Invalid E-Mail", )); } @@ -381,10 +381,7 @@ impl AtprotoCallback for AuthProviderCallback { Some(user.failed_login_attempts.unwrap_or_default() + 1); user.save(old_email).await?; - return Err(ErrorResponse::new( - ErrorResponseType::Forbidden, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, err)); } if Some(user.email.as_str()) != email.as_deref() { diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index f50f3e98..0271353d 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -1205,14 +1205,11 @@ impl AuthProviderIdClaims<'_> { let _header = parts.next().ok_or_else(|| { ErrorResponse::new( ErrorResponseType::BadRequest, - "incorrect ID did not contain claims".to_string(), + "incorrect ID did not contain claims", ) })?; let claims = parts.next().ok_or_else(|| { - ErrorResponse::new( - ErrorResponseType::BadRequest, - "ID token was unsigned".to_string(), - ) + ErrorResponse::new(ErrorResponseType::BadRequest, "ID token was unsigned") })?; debug!("upstream ID token claims:\n{}", claims); let json_bytes = base64_url_no_pad_decode(claims)?; @@ -1227,10 +1224,7 @@ impl AuthProviderIdClaims<'_> { if self.email.is_none() { let err = "No `email` in ID token claims. This is a mandatory claim"; error!("{}", err); - return Err(ErrorResponse::new( - ErrorResponseType::BadRequest, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::BadRequest, err)); } let claims_user_id = if let Some(sub) = &self.sub { @@ -1242,10 +1236,7 @@ impl AuthProviderIdClaims<'_> { } else { let err = "Cannot find any user id in the response"; error!("{}", err); - return Err(ErrorResponse::new( - ErrorResponseType::BadRequest, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::BadRequest, err)); } // We need to create a real string here, since we don't know what json type we get. // Any json number would become a String too, which is what we need for compatibility. @@ -1274,7 +1265,7 @@ impl AuthProviderIdClaims<'_> { if link.provider_id != provider.id { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad provider_id in link cookie".to_string(), + "bad provider_id in link cookie", )); } @@ -1284,7 +1275,7 @@ impl AuthProviderIdClaims<'_> { // multiple accounts. return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad user_id in link cookie".to_string(), + "bad user_id in link cookie", )); } @@ -1292,7 +1283,7 @@ impl AuthProviderIdClaims<'_> { if link.user_email != user.email { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "Invalid E-Mail".to_string(), + "Invalid E-Mail", )); } @@ -1321,7 +1312,7 @@ impl AuthProviderIdClaims<'_> { if provider.admin_claim_value.is_none() { return Err(ErrorResponse::new( ErrorResponseType::Internal, - "Misconfigured Auth Provider - admin claim path without value".to_string(), + "Misconfigured Auth Provider - admin claim path without value", )); } @@ -1362,7 +1353,7 @@ impl AuthProviderIdClaims<'_> { if provider.mfa_claim_value.is_none() { return Err(ErrorResponse::new( ErrorResponseType::Internal, - "Misconfigured Auth Provider - mfa claim path without value".to_string(), + "Misconfigured Auth Provider - mfa claim path without value", )); } @@ -1422,10 +1413,7 @@ impl AuthProviderIdClaims<'_> { Some(user.failed_login_attempts.unwrap_or_default() + 1); user.save(old_email).await?; - return Err(ErrorResponse::new( - ErrorResponseType::Forbidden, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, err)); } // check / update email From b475fcc6dc68cae7242bbe7c42a6600ee09a51df Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 15:56:41 +0000 Subject: [PATCH 19/22] use `chrono` instead of `time` --- src/models/src/entity/atproto.rs | 3 +-- src/models/src/entity/auth_providers.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 98d324df..4afa806b 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -23,7 +23,6 @@ use rauthy_common::constants::{ COOKIE_UPSTREAM_CALLBACK, PROVIDER_LINK_COOKIE, UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS, }; use rauthy_error::{ErrorResponse, ErrorResponseType}; -use time::OffsetDateTime; use tracing::{debug, error}; use utoipa::{PartialSchema, ToSchema}; @@ -360,7 +359,7 @@ impl AtprotoCallback for AuthProviderCallback { }; debug!("user_opt:\n{:?}", user_opt); - let now = OffsetDateTime::now_utc().unix_timestamp(); + let now = chrono::Utc::now().timestamp(); let user = if let Some(mut user) = user_opt { let mut old_email = None; let mut forbidden_error = None; diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index 0271353d..0fbfc368 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -48,7 +48,6 @@ use std::borrow::Cow; use std::fmt::Write; use std::str::FromStr; use std::time::Duration; -use time::OffsetDateTime; use tracing::{debug, error}; use utoipa::ToSchema; @@ -1388,7 +1387,7 @@ impl AuthProviderIdClaims<'_> { } } - let now = OffsetDateTime::now_utc().unix_timestamp(); + let now = chrono::Utc::now().timestamp(); let user = if let Some(mut user) = user_opt { let mut old_email = None; let mut forbidden_error = None; From c67e171a7451aab7eb8783f5ff967e5f6acc56b1 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 15:59:58 +0000 Subject: [PATCH 20/22] provide meaning default for `given_name` --- src/models/src/entity/atproto.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs index 4afa806b..394b320b 100644 --- a/src/models/src/entity/atproto.rs +++ b/src/models/src/entity/atproto.rs @@ -397,7 +397,7 @@ impl AtprotoCallback for AuthProviderCallback { } else { let new_user = User { email: email.as_ref().unwrap().to_string(), - given_name: "N/A".to_string(), + given_name: "Unknown".to_string(), family_name: None, roles: Default::default(), enabled: true, From f7129720d7172c2d4569e28fb94339a7d5013180 Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 17:53:16 +0000 Subject: [PATCH 21/22] provide atproto configuration --- rauthy.cfg | 8 ++++++++ src/api/src/atproto.rs | 17 +++++++++++++++++ src/common/src/constants.rs | 5 +++++ 3 files changed, 30 insertions(+) diff --git a/rauthy.cfg b/rauthy.cfg index 250b314c..19314149 100644 --- a/rauthy.cfg +++ b/rauthy.cfg @@ -213,6 +213,14 @@ SUSPICIOUS_REQUESTS_BLACKLIST=1440 # default: false SUSPICIOUS_REQUESTS_LOG=true +##################################### +############## ATPROTO ############## +##################################### + +# Set to `true` to enable the atproto client. +# default: false +# ATPROTO_ENABLE=false + ##################################### ############# BACKUPS ############### ##################################### diff --git a/src/api/src/atproto.rs b/src/api/src/atproto.rs index 900e32db..95adf85d 100644 --- a/src/api/src/atproto.rs +++ b/src/api/src/atproto.rs @@ -6,6 +6,7 @@ use actix_web::{ HttpRequest, HttpResponse, }; use rauthy_api_types::{atproto, auth_providers::ProviderLookupResponse}; +use rauthy_common::constants::ATPROTO_ENABLE; use rauthy_error::{ErrorResponse, ErrorResponseType}; use rauthy_models::{ app_state::AppState, @@ -24,6 +25,8 @@ use crate::{map_auth_step, ReqPrincipal}; )] #[get("/atproto/client_metadata")] pub async fn get_client_metadata(data: web::Data) -> Result { + is_atproto_enabled()?; + Ok(HttpResponse::Ok() .insert_header(( header::ACCESS_CONTROL_ALLOW_ORIGIN, @@ -48,6 +51,7 @@ pub async fn post_login( payload: Json, principal: ReqPrincipal, ) -> Result { + is_atproto_enabled()?; principal.validate_session_auth_or_init()?; let payload = payload.into_inner(); @@ -81,6 +85,7 @@ pub async fn post_callback( payload: Query, principal: ReqPrincipal, ) -> Result { + is_atproto_enabled()?; principal.validate_session_auth_or_init()?; let payload = payload.into_inner(); @@ -103,3 +108,15 @@ pub async fn post_callback( Ok(resp) } + +#[inline(always)] +fn is_atproto_enabled() -> Result<(), ErrorResponse> { + if *ATPROTO_ENABLE { + Ok(()) + } else { + Err(ErrorResponse::new( + ErrorResponseType::Internal, + "The atproto client is disabled on this instance", + )) + } +} diff --git a/src/common/src/constants.rs b/src/common/src/constants.rs index 56e07bbd..0df1876b 100644 --- a/src/common/src/constants.rs +++ b/src/common/src/constants.rs @@ -363,6 +363,11 @@ lazy_static! { .parse::() .expect("EXPERIMENTAL_FED_CM_ENABLE cannot be parsed to bool - bad format"); + pub static ref ATPROTO_ENABLE: bool = env::var("ATPROTO_ENABLE") + .unwrap_or_else(|_| String::from("false")) + .parse::() + .expect("ATPROTO_ENABLE cannot be parsed to bool - bad format"); + pub static ref REFRESH_TOKEN_LIFETIME: u16 = env::var("REFRESH_TOKEN_LIFETIME") .unwrap_or_else(|_| String::from("48")) .parse::() From 7b26c4f8c0f6f1b375dec7cd8484b69f668d0a4a Mon Sep 17 00:00:00 2001 From: avdb13 Date: Tue, 17 Dec 2024 18:50:55 +0000 Subject: [PATCH 22/22] separate atproto cache from native caches --- src/models/src/database.rs | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/models/src/database.rs b/src/models/src/database.rs index f628ddd6..43215fd1 100644 --- a/src/models/src/database.rs +++ b/src/models/src/database.rs @@ -36,6 +36,7 @@ struct Migrations; pub enum Cache { App, AuthCode, + Atproto, DeviceCode, AuthProviderCallback, ClientDynamic, @@ -243,13 +244,13 @@ impl Store for DB { type Error = hiqlite::Error; async fn get(&self, key: &String) -> Result, Self::Error> { - Self::client().get(Cache::AuthProviderCallback, key).await + Self::client().get(Cache::Atproto, key).await } async fn set(&self, key: String, value: InternalStateData) -> Result<(), Self::Error> { Self::client() .put( - Cache::AuthProviderCallback, + Cache::Atproto, key, &value, CACHE_TTL_AUTH_PROVIDER_CALLBACK, @@ -258,15 +259,11 @@ impl Store for DB { } async fn del(&self, key: &String) -> Result<(), Self::Error> { - Self::client() - .delete(Cache::AuthProviderCallback, key.to_string()) - .await + Self::client().delete(Cache::Atproto, key.to_string()).await } async fn clear(&self) -> Result<(), Self::Error> { - Self::client() - .clear_cache(Cache::AuthProviderCallback) - .await + Self::client().clear_cache(Cache::Atproto).await } } @@ -275,22 +272,22 @@ impl StateStore for DB {} impl Store for DB { type Error = hiqlite::Error; - async fn get(&self, key: &Did) -> Result, Self::Error> { - Self::client().get(Cache::Session, key.to_string()).await + async fn get(&self, key: &Did) -> Result, Self::Error> { + Self::client().get(Cache::Atproto, key.to_string()).await } - async fn set(&self, key: Did, value: Session) -> Result<(), Self::Error> { + async fn set(&self, key: Did, value: Atproto) -> Result<(), Self::Error> { Self::client() - .put(Cache::Session, key.to_string(), &value, CACHE_TTL_SESSION) + .put(Cache::Atproto, key.to_string(), &value, CACHE_TTL_SESSION) .await } async fn del(&self, key: &Did) -> Result<(), Self::Error> { - Self::client().delete(Cache::Session, key.to_string()).await + Self::client().delete(Cache::Atproto, key.to_string()).await } async fn clear(&self) -> Result<(), Self::Error> { - Self::client().clear_cache(Cache::Session).await + Self::client().clear_cache(Cache::Atproto).await } }