diff --git a/Cargo.toml b/Cargo.toml index cca776ce..7fda3474 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -110,3 +110,10 @@ webauthn-rs = { version = "0.5", features = [ "danger-allow-state-serialisation", "danger-credential-internals" ] } webauthn-rs-proto = "0.5" + +atrium-api = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-common = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-identity = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } +atrium-oauth-client = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" } + +hickory-resolver = { version = "0.24.1", features = ["tokio", "tokio-rustls", "rustls"] } diff --git a/frontend/src/components/admin/providers/ProviderConfig.svelte b/frontend/src/components/admin/providers/ProviderConfig.svelte index 571faa64..fdc4f17a 100644 --- a/frontend/src/components/admin/providers/ProviderConfig.svelte +++ b/frontend/src/components/admin/providers/ProviderConfig.svelte @@ -58,7 +58,7 @@ token_endpoint: yup.string().url(), userinfo_endpoint: yup.string().url(), - name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128"), + name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved"), client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128"), client_secret: yup.string().trim().max(256, "Max 256 characters"), scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128"), diff --git a/frontend/src/components/admin/providers/ProviderTileAddNew.svelte b/frontend/src/components/admin/providers/ProviderTileAddNew.svelte index 07877039..fe68f5e6 100644 --- a/frontend/src/components/admin/providers/ProviderTileAddNew.svelte +++ b/frontend/src/components/admin/providers/ProviderTileAddNew.svelte @@ -79,7 +79,7 @@ token_endpoint: yup.string().url().required('Required'), userinfo_endpoint: yup.string().url().required('Required'), - name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").required('Required'), + name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved").required('Required'), client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128").required('Required'), client_secret: yup.string().trim().max(256, "Max 256 characters"), scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128").required('Required'), diff --git a/rauthy.cfg b/rauthy.cfg index 250b314c..19314149 100644 --- a/rauthy.cfg +++ b/rauthy.cfg @@ -213,6 +213,14 @@ SUSPICIOUS_REQUESTS_BLACKLIST=1440 # default: false SUSPICIOUS_REQUESTS_LOG=true +##################################### +############## ATPROTO ############## +##################################### + +# Set to `true` to enable the atproto client. +# default: false +# ATPROTO_ENABLE=false + ##################################### ############# BACKUPS ############### ##################################### diff --git a/src/api/src/atproto.rs b/src/api/src/atproto.rs new file mode 100644 index 00000000..95adf85d --- /dev/null +++ b/src/api/src/atproto.rs @@ -0,0 +1,122 @@ +use actix_web::{ + get, + http::header::{self, HeaderValue, LOCATION}, + post, + web::{self, Json, Query}, + HttpRequest, HttpResponse, +}; +use rauthy_api_types::{atproto, auth_providers::ProviderLookupResponse}; +use rauthy_common::constants::ATPROTO_ENABLE; +use rauthy_error::{ErrorResponse, ErrorResponseType}; +use rauthy_models::{ + app_state::AppState, + entity::{atproto::AtprotoCallback, auth_providers::AuthProviderCallback}, +}; + +use crate::{map_auth_step, ReqPrincipal}; + +#[utoipa::path( + get, + path = "/atproto/client_metadata", + tag = "atproto", + responses( + (status = 200, description = "OK"), + ), +)] +#[get("/atproto/client_metadata")] +pub async fn get_client_metadata(data: web::Data) -> Result { + is_atproto_enabled()?; + + Ok(HttpResponse::Ok() + .insert_header(( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str("*").unwrap(), + )) + .json(&data.atproto.client_metadata)) +} + +#[utoipa::path( + post, + path = "/atproto/login", + tag = "atproto", + responses( + (status = 202, description = "Accepted"), + (status = 400, description = "BadRequest", body = ErrorResponse), + (status = 404, description = "NotFound", body = ErrorResponse), + ), +)] +#[post("/atproto/login")] +pub async fn post_login( + data: web::Data, + payload: Json, + principal: ReqPrincipal, +) -> Result { + is_atproto_enabled()?; + principal.validate_session_auth_or_init()?; + + let payload = payload.into_inner(); + let (cookie, xsrf_token, location) = + ::login_start(&data, payload).await?; + + Ok(HttpResponse::Accepted() + .insert_header((LOCATION, location)) + .cookie(cookie) + .body(xsrf_token)) +} + +#[utoipa::path( + post, + path = "/atproto/callback", + tag = "atproto", + responses( + (status = 200, description = "OK", body = ProviderLookupResponse), + (status = 400, description = "BadRequest", body = ErrorResponse), + (status = 404, description = "NotFound", body = ErrorResponse), + ), +)] +#[post("/atproto/callback")] +#[tracing::instrument( + name = "post_provider_callback", + skip_all, fields(callback_id = payload.state) +)] +pub async fn post_callback( + data: web::Data, + req: HttpRequest, + payload: Query, + principal: ReqPrincipal, +) -> Result { + is_atproto_enabled()?; + principal.validate_session_auth_or_init()?; + + let payload = payload.into_inner(); + let session = principal.get_session()?; + let (auth_step, cookie) = ::login_finish( + &data, + &req, + &payload, + session.clone(), + ) + .await?; + + let mut resp = map_auth_step(auth_step, &req).await?; + resp.add_cookie(&cookie).map_err(|err| { + ErrorResponse::new( + ErrorResponseType::Internal, + format!("Error adding cookie after map_auth_step: {}", err), + ) + })?; + + Ok(resp) +} + +#[inline(always)] +fn is_atproto_enabled() -> Result<(), ErrorResponse> { + if *ATPROTO_ENABLE { + Ok(()) + } else { + Err(ErrorResponse::new( + ErrorResponseType::Internal, + "The atproto client is disabled on this instance", + )) + } +} diff --git a/src/api/src/auth_providers.rs b/src/api/src/auth_providers.rs index aa1f257e..20d15ee3 100644 --- a/src/api/src/auth_providers.rs +++ b/src/api/src/auth_providers.rs @@ -68,6 +68,13 @@ pub async fn post_provider( ) -> Result { principal.validate_admin_session()?; + if payload.name == "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Must not contain a reserved name".to_string(), + )); + } + if !payload.use_pkce && payload.client_secret.is_none() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, @@ -261,6 +268,13 @@ pub async fn put_provider( ) -> Result { principal.validate_admin_session()?; + if payload.name == "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Must not contain a reserved name".to_string(), + )); + } + if !payload.use_pkce && payload.client_secret.is_none() { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, diff --git a/src/api/src/lib.rs b/src/api/src/lib.rs index 07aff120..5a1ceccf 100644 --- a/src/api/src/lib.rs +++ b/src/api/src/lib.rs @@ -17,6 +17,7 @@ use rust_embed::RustEmbed; use tracing::error; pub mod api_keys; +pub mod atproto; pub mod auth_providers; pub mod blacklist; pub mod clients; diff --git a/src/api/src/openapi.rs b/src/api/src/openapi.rs index 1dc921e6..ecb48271 100644 --- a/src/api/src/openapi.rs +++ b/src/api/src/openapi.rs @@ -1,8 +1,11 @@ use crate::{ - api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles, - scopes, sessions, users, + api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, + roles, scopes, sessions, users, }; use actix_web::web; +use rauthy_api_types::atproto::{ + CallbackRequest as AtprotoCallbackRequest, LoginRequest as AtprotoLoginRequest, +}; use rauthy_api_types::{ api_keys::*, auth_providers::*, blacklist::*, clients::*, events::*, fed_cm::*, generic::*, groups::*, oidc::*, roles::*, scopes::*, sessions::*, users::*, @@ -26,6 +29,10 @@ use utoipa::{openapi, OpenApi}; api_keys::get_api_key_test, api_keys::put_api_key_secret, + atproto::get_client_metadata, + atproto::post_login, + atproto::post_callback, + auth_providers::post_providers, auth_providers::post_provider, auth_providers::post_provider_lookup, @@ -152,6 +159,7 @@ use utoipa::{openapi, OpenApi}; components( schemas( entity::colors::Colors, + entity::atproto::DnsTxtResolver, entity::fed_cm::FedCMAccount, entity::fed_cm::FedCMAccounts, entity::fed_cm::FedCMIdPBranding, @@ -186,6 +194,8 @@ use utoipa::{openapi, OpenApi}; ErrorResponseType, ApiKeyRequest, + AtprotoCallbackRequest, + AtprotoLoginRequest, AuthCodeRequest, AuthRequest, IpBlacklistRequest, @@ -294,6 +304,7 @@ use utoipa::{openapi, OpenApi}; (name = "generic", description = "Generic endpoints"), (name = "webid", description = "WebID endpoints"), (name = "fed_cm", description = "Experimental FedCM endpoints"), + (name = "atproto", description = "Experimental ATProto endpoints"), (name = "deprecated", description = "Deprecated endpoints - will be removed in a future version"), ), )] diff --git a/src/api_types/src/atproto.rs b/src/api_types/src/atproto.rs new file mode 100644 index 00000000..5faf311a --- /dev/null +++ b/src/api_types/src/atproto.rs @@ -0,0 +1,40 @@ +use rauthy_common::constants::{RE_ALNUM, RE_URI}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use validator::Validate; + +#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)] +pub struct LoginRequest { + /// Validation: + /// `^(did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]|([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)$` + pub at_id: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub redirect_uri: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub state: Option, + + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub pkce_challenge: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)] +pub struct CallbackRequest { + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub state: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub code: String, + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub iss: Option, + /// Validation: `[a-zA-Z0-9]` + #[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))] + pub xsrf_token: String, + /// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$` + #[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))] + pub pkce_verifier: String, +} diff --git a/src/api_types/src/lib.rs b/src/api_types/src/lib.rs index 2bd6b09e..58b8ef45 100644 --- a/src/api_types/src/lib.rs +++ b/src/api_types/src/lib.rs @@ -1,4 +1,5 @@ pub mod api_keys; +pub mod atproto; pub mod auth_providers; pub mod blacklist; pub mod clients; diff --git a/src/bin/src/main.rs b/src/bin/src/main.rs index 6828eb64..3e5ac03a 100644 --- a/src/bin/src/main.rs +++ b/src/bin/src/main.rs @@ -16,8 +16,8 @@ use rauthy_common::utils::UseDummyAddress; use rauthy_common::{is_hiqlite, is_sqlite, password_hasher}; use rauthy_handlers::openapi::ApiDoc; use rauthy_handlers::{ - api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles, - scopes, sessions, users, + api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, + roles, scopes, sessions, users, }; use rauthy_middlewares::csrf_protection::CsrfProtectionMiddleware; use rauthy_middlewares::ip_blacklist::RauthyIpBlacklistMiddleware; @@ -556,7 +556,10 @@ async fn actix_main(app_state: web::Data) -> std::io::Result<()> { .service(oidc::get_well_known) .service(generic::get_health) .service(generic::get_ready) - .service(generic::get_static_assets), + .service(generic::get_static_assets) + .service(atproto::get_client_metadata) + .service(atproto::post_login) + .service(atproto::post_callback), ), ); diff --git a/src/common/src/constants.rs b/src/common/src/constants.rs index 56e07bbd..0df1876b 100644 --- a/src/common/src/constants.rs +++ b/src/common/src/constants.rs @@ -363,6 +363,11 @@ lazy_static! { .parse::() .expect("EXPERIMENTAL_FED_CM_ENABLE cannot be parsed to bool - bad format"); + pub static ref ATPROTO_ENABLE: bool = env::var("ATPROTO_ENABLE") + .unwrap_or_else(|_| String::from("false")) + .parse::() + .expect("ATPROTO_ENABLE cannot be parsed to bool - bad format"); + pub static ref REFRESH_TOKEN_LIFETIME: u16 = env::var("REFRESH_TOKEN_LIFETIME") .unwrap_or_else(|_| String::from("48")) .parse::() diff --git a/src/models/Cargo.toml b/src/models/Cargo.toml index f80e2f9e..45fc004e 100644 --- a/src/models/Cargo.toml +++ b/src/models/Cargo.toml @@ -81,6 +81,13 @@ validator = { workspace = true } webauthn-rs = { workspace = true } webauthn-rs-proto = { workspace = true } +atrium-api = { workspace = true } +atrium-common = { workspace = true } +atrium-identity = { workspace = true } +atrium-oauth-client = { workspace = true } + +hickory-resolver = { workspace = true } + [dev-dependencies] pretty_assertions = "1" rstest = "0.18.2" diff --git a/src/models/src/app_state.rs b/src/models/src/app_state.rs index 0e83c9ea..8df1354e 100644 --- a/src/models/src/app_state.rs +++ b/src/models/src/app_state.rs @@ -1,4 +1,5 @@ use crate::email::EMail; +use crate::entity::atproto::AtprotoClient; use crate::events::event::Event; use crate::events::ip_blacklist_handler::IpBlacklistReq; use crate::events::listener::EventRouterMsg; @@ -31,6 +32,7 @@ pub struct AppState { pub tx_events_router: flume::Sender, pub tx_ip_blacklist: flume::Sender, pub webauthn: Arc, + pub atproto: AtprotoClient, } impl AppState { @@ -154,6 +156,9 @@ impl AppState { .rp_name(&rp_name); let webauthn = Arc::new(builder.build().expect("Invalid configuration")); + let atproto = AtprotoClient::new(&listen_scheme, &public_url) + .expect("failed to initialize atproto client"); + Ok(Self { public_url, argon2_params, @@ -170,6 +175,7 @@ impl AppState { tx_events_router, tx_ip_blacklist, webauthn, + atproto, }) } diff --git a/src/models/src/database.rs b/src/models/src/database.rs index 01ad66c5..43215fd1 100644 --- a/src/models/src/database.rs +++ b/src/models/src/database.rs @@ -3,8 +3,14 @@ use crate::entity::db_version::DbVersion; use crate::migration::db_migrate_dev::migrate_dev_data; use crate::migration::{anti_lockout, db_migrate, init_prod}; use actix_web::web; +use atrium_api::types::string::Did; +use atrium_common::store::Store; +use atrium_oauth_client::store::session::{Session, SessionStore}; +use atrium_oauth_client::store::state::{InternalStateData, StateStore}; use hiqlite::NodeConfig; -use rauthy_common::constants::{DATABASE_URL, DEV_MODE}; +use rauthy_common::constants::{ + CACHE_TTL_AUTH_PROVIDER_CALLBACK, CACHE_TTL_SESSION, DATABASE_URL, DEV_MODE, +}; use rauthy_common::{is_hiqlite, is_postgres}; use rauthy_error::{ErrorResponse, ErrorResponseType}; use serde::{Deserialize, Serialize}; @@ -30,6 +36,7 @@ struct Migrations; pub enum Cache { App, AuthCode, + Atproto, DeviceCode, AuthProviderCallback, ClientDynamic, @@ -232,3 +239,56 @@ impl DB { Ok(()) } } + +impl Store for DB { + type Error = hiqlite::Error; + + async fn get(&self, key: &String) -> Result, Self::Error> { + Self::client().get(Cache::Atproto, key).await + } + + async fn set(&self, key: String, value: InternalStateData) -> Result<(), Self::Error> { + Self::client() + .put( + Cache::Atproto, + key, + &value, + CACHE_TTL_AUTH_PROVIDER_CALLBACK, + ) + .await + } + + async fn del(&self, key: &String) -> Result<(), Self::Error> { + Self::client().delete(Cache::Atproto, key.to_string()).await + } + + async fn clear(&self) -> Result<(), Self::Error> { + Self::client().clear_cache(Cache::Atproto).await + } +} + +impl StateStore for DB {} + +impl Store for DB { + type Error = hiqlite::Error; + + async fn get(&self, key: &Did) -> Result, Self::Error> { + Self::client().get(Cache::Atproto, key.to_string()).await + } + + async fn set(&self, key: Did, value: Atproto) -> Result<(), Self::Error> { + Self::client() + .put(Cache::Atproto, key.to_string(), &value, CACHE_TTL_SESSION) + .await + } + + async fn del(&self, key: &Did) -> Result<(), Self::Error> { + Self::client().delete(Cache::Atproto, key.to_string()).await + } + + async fn clear(&self) -> Result<(), Self::Error> { + Self::client().clear_cache(Cache::Atproto).await + } +} + +impl SessionStore for DB {} diff --git a/src/models/src/entity/atproto.rs b/src/models/src/entity/atproto.rs new file mode 100644 index 00000000..394b320b --- /dev/null +++ b/src/models/src/entity/atproto.rs @@ -0,0 +1,452 @@ +use std::{future::Future, ops::Deref, sync::Arc}; + +use actix_web::{ + cookie::Cookie, + http::header::{self, HeaderValue}, + web, HttpRequest, +}; +use atrium_api::{agent::Agent, com::atproto::server::get_session, types::Object}; +use atrium_identity::{ + did::{CommonDidResolver, CommonDidResolverConfig, DEFAULT_PLC_DIRECTORY_URL}, + handle::{ + AtprotoHandleResolver, AtprotoHandleResolverConfig, DnsTxtResolver as DnsTxtResolverTrait, + }, +}; +use atrium_oauth_client::{ + AtprotoClientMetadata, AuthMethod, AuthorizeOptions, CallbackParams, DefaultHttpClient, + GrantType, KnownScope, OAuthClient, OAuthClientConfig, OAuthResolverConfig, Scope, +}; +use cryptr::utils::secure_random_alnum; +use hickory_resolver::{proto::rr::rdata::TXT, TokioAsyncResolver}; +use rauthy_api_types::atproto; +use rauthy_common::constants::{ + COOKIE_UPSTREAM_CALLBACK, PROVIDER_LINK_COOKIE, UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS, +}; +use rauthy_error::{ErrorResponse, ErrorResponseType}; +use tracing::{debug, error}; +use utoipa::{PartialSchema, ToSchema}; + +use crate::{ + api_cookie::ApiCookie, + app_state::AppState, + database::DB, + entity::{ + auth_codes::AuthCode, auth_providers::AuthProviderLinkCookie, clients::Client, users::User, + }, + AuthStep, AuthStepLoggedIn, ListenScheme, +}; + +use super::{ + auth_providers::{AuthProviderCallback, AuthProviderType}, + sessions::Session, +}; + +#[derive(Clone)] +pub struct AtprotoClient( + Arc< + OAuthClient< + DB, + DB, + CommonDidResolver, + AtprotoHandleResolver, + >, + >, +); + +impl std::fmt::Debug for AtprotoClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("AtprotoClient").finish() + } +} + +impl AtprotoClient { + pub fn new( + listen_scheme: &ListenScheme, + public_url: &str, + ) -> Result { + let http_client = Arc::new(DefaultHttpClient::default()); + + let listen_scheme = match listen_scheme { + ListenScheme::Http | ListenScheme::UnixHttp => "http", + ListenScheme::Https | ListenScheme::HttpHttps | ListenScheme::UnixHttps => "https", + }; + + let client_metadata = AtprotoClientMetadata { + client_id: format!("{listen_scheme}://{public_url}/auth/v1/atproto/client_metadata"), + client_uri: String::new(), + redirect_uris: vec![format!( + "{listen_scheme}://{public_url}/auth/v1/atproto/callback" + )], + token_endpoint_auth_method: AuthMethod::None, + grant_types: vec![GrantType::AuthorizationCode], + scopes: vec![Scope::Known(KnownScope::Atproto)], + jwks_uri: None, + token_endpoint_auth_signing_alg: None, + }; + + let config = OAuthClientConfig { + client_metadata: client_metadata.clone(), + keys: None, + resolver: OAuthResolverConfig { + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { + plc_directory_url: DEFAULT_PLC_DIRECTORY_URL.to_string(), + http_client: http_client.clone(), + }), + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { + dns_txt_resolver: DnsTxtResolver::default(), + http_client: http_client.clone(), + }), + authorization_server_metadata: Default::default(), + protected_resource_metadata: Default::default(), + }, + state_store: DB, + session_store: DB, + }; + + OAuthClient::new(config).map(Arc::new).map(AtprotoClient) + } +} + +impl Deref for AtprotoClient { + type Target = OAuthClient< + DB, + DB, + CommonDidResolver, + AtprotoHandleResolver, + >; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub struct DnsTxtResolver { + resolver: TokioAsyncResolver, +} + +impl Default for DnsTxtResolver { + fn default() -> Self { + Self { + resolver: TokioAsyncResolver::tokio_from_system_conf() + .expect("failed to create resolver"), + } + } +} + +impl ToSchema for DnsTxtResolver {} + +impl PartialSchema for DnsTxtResolver { + fn schema() -> utoipa::openapi::RefOr { + utoipa::openapi::RefOr::T(utoipa::openapi::Schema::Object( + utoipa::openapi::ObjectBuilder::new().build(), + )) + } +} + +impl DnsTxtResolverTrait for DnsTxtResolver { + async fn resolve( + &self, + query: &str, + ) -> Result, Box> { + let txt_lookup = self.resolver.txt_lookup(query).await?; + Ok(txt_lookup.iter().map(TXT::to_string).collect()) + } +} + +pub trait AtprotoCallback { + fn login_start<'a>( + data: &'a web::Data, + payload: atproto::LoginRequest, + ) -> impl Future, String, HeaderValue), ErrorResponse>>; + + fn login_finish<'a>( + data: &'a web::Data, + req: &'a HttpRequest, + payload: &'a atproto::CallbackRequest, + session: Session, + ) -> impl Future), ErrorResponse>>; +} + +impl AtprotoCallback for AuthProviderCallback { + async fn login_start<'a>( + data: &'a web::Data, + payload: atproto::LoginRequest, + ) -> Result<(Cookie<'a>, String, HeaderValue), ErrorResponse> { + let slf = Self { + callback_id: secure_random_alnum(32), + xsrf_token: secure_random_alnum(32), + typ: AuthProviderType::Custom, + + req_client_id: String::from("atproto"), + req_scopes: None, + req_redirect_uri: payload.redirect_uri.clone(), + req_state: payload.state.clone(), + req_nonce: None, + req_code_challenge: None, + req_code_challenge_method: None, + + provider_id: String::from("atproto"), + + pkce_challenge: payload.pkce_challenge, + }; + + slf.save().await?; + + let options = AuthorizeOptions { + state: Some(slf.callback_id.clone()), + ..Default::default() + }; + + match data.atproto.authorize(&payload.at_id, options).await { + Ok(location) => { + let cookie = ApiCookie::build( + COOKIE_UPSTREAM_CALLBACK, + &slf.callback_id, + UPSTREAM_AUTH_CALLBACK_TIMEOUT_SECS as i64, + ); + let header = + HeaderValue::from_str(&location).expect("Location HeaderValue to be correct"); + + Ok((cookie, slf.xsrf_token, header)) + } + Err(error) => { + error!(%error, "failed to build pushed authorization request for atproto"); + + Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "failed to build pushed authorization request for atproto", + )) + } + } + } + + async fn login_finish<'a>( + data: &'a web::Data, + req: &'a HttpRequest, + payload: &'a atproto::CallbackRequest, + session: Session, + ) -> Result<(AuthStep, Cookie<'a>), ErrorResponse> { + // the callback id for the cache should be inside the encrypted cookie + let callback_id = ApiCookie::from_req(req, COOKIE_UPSTREAM_CALLBACK).ok_or_else(|| { + ErrorResponse::new( + ErrorResponseType::Forbidden, + "Missing encrypted callback cookie", + ) + })?; + + // validate state + if callback_id != payload.state { + AuthProviderCallback::delete(callback_id).await?; + + error!("`state` does not match"); + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "`state` does not match", + )); + } + debug!("callback state is valid"); + + // validate csrf token + let slf = AuthProviderCallback::find(callback_id).await?; + if slf.xsrf_token != payload.xsrf_token { + AuthProviderCallback::delete(slf.callback_id).await?; + + error!("invalid CSRF token"); + return Err(ErrorResponse::new( + ErrorResponseType::Unauthorized, + "invalid CSRF token", + )); + } + debug!("callback csrf token is valid"); + + // request is valid -> fetch token for the user + let params = CallbackParams { + code: payload.code.clone(), + state: Some(payload.state.clone()), + iss: payload.iss.clone(), + }; + // return early if we got any error + let (session_manager, _) = data.atproto.callback(params).await.map_err(|error| { + error!(%error, "failed to complete authorization callback for atproto"); + + ErrorResponse::new( + ErrorResponseType::BadRequest, + "failed to complete authorization callback for atproto", + ) + })?; + + let agent = Agent::new(session_manager); + + let (did, email, email_confirmed) = match agent.api.com.atproto.server.get_session().await { + Ok(Object { + data: + get_session::OutputData { + did, + email, + email_confirmed, + .. + }, + .. + }) => (did, email, email_confirmed), + Err(error) => { + error!(%error, "failed to get session for atproto"); + + return Err(ErrorResponse::new( + ErrorResponseType::Internal, + "failed to get session for atproto", + )); + } + }; + + let link_cookie = ApiCookie::from_req(req, PROVIDER_LINK_COOKIE) + .and_then(|value| AuthProviderLinkCookie::try_from(value.as_str()).ok()); + + if email.is_none() { + todo!() + } + + let claims_user_id = did.to_string(); + + let user_opt = match User::find_by_federation("atproto", &claims_user_id).await { + Ok(user) => { + debug!( + "found already existing user by federation lookup: {:?}", + user + ); + Some(user) + } + Err(_) => { + debug!("did not find already existing user by federation lookup - making sure email does not exist"); + + if let Ok(mut user) = User::find_by_email(email.as_ref().unwrap().to_string()).await + { + if let Some(link) = link_cookie.as_ref() { + if link.provider_id != "atproto" { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "bad provider_id in link cookie", + )); + } + + if link.user_id != user.id { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "bad user_id in link cookie", + )); + } + + if link.user_email != user.email { + return Err(ErrorResponse::new( + ErrorResponseType::BadRequest, + "Invalid E-Mail", + )); + } + + user.auth_provider_id = Some("atproto".to_owned()); + user.federation_uid = Some(claims_user_id.clone()); + + Some(user) + } else { + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, format!( + "User with email '{}' already exists but is not linked to this provider.", + user.email + ))); + } + } else { + None + } + } + }; + debug!("user_opt:\n{:?}", user_opt); + + let now = chrono::Utc::now().timestamp(); + let user = if let Some(mut user) = user_opt { + let mut old_email = None; + let mut forbidden_error = None; + + if user.federation_uid.is_none() + || user.federation_uid.as_deref() != Some(&claims_user_id) + { + forbidden_error = Some("non-federated user or ID mismatch"); + } + + if user.auth_provider_id.as_deref() != Some("atproto") { + forbidden_error = Some("invalid login from wrong auth provider"); + } + + if let Some(err) = forbidden_error { + user.last_failed_login = Some(now); + user.failed_login_attempts = + Some(user.failed_login_attempts.unwrap_or_default() + 1); + user.save(old_email).await?; + + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, err)); + } + + if Some(user.email.as_str()) != email.as_deref() { + old_email = Some(user.email); + user.email = email.as_ref().unwrap().to_string(); + } + + user.last_login = Some(now); + user.last_failed_login = None; + user.failed_login_attempts = None; + + user.save(old_email).await?; + user + } else { + let new_user = User { + email: email.as_ref().unwrap().to_string(), + given_name: "Unknown".to_string(), + family_name: None, + roles: Default::default(), + enabled: true, + email_verified: email_confirmed.unwrap_or(false), + last_login: Some(now), + language: Default::default(), + auth_provider_id: Some("atproto".to_owned()), + federation_uid: Some(claims_user_id.to_string()), + ..Default::default() + }; + User::create_federated(new_user).await? + }; + + user.check_enabled()?; + user.check_expired()?; + + if link_cookie.is_some() { + return Ok(( + AuthStep::ProviderLink, + AuthProviderLinkCookie::deletion_cookie(), + )); + } + + let client = Client::default(); + + let code = AuthCode::new( + user.id.clone(), + "atproto".to_owned(), + Some(session.id.clone()), + slf.req_code_challenge, + slf.req_code_challenge_method, + slf.req_nonce, + vec!["atproto".to_owned()], + client.auth_code_lifetime, + ); + code.save().await?; + + let auth_step = AuthStep::LoggedIn(AuthStepLoggedIn { + user_id: user.id, + email: user.email, + header_loc: ( + header::LOCATION, + HeaderValue::from_str(&(format!("{}?code={}", slf.req_redirect_uri, code.id)))?, + ), + header_csrf: Session::get_csrf_header(&session.csrf_token), + header_origin: None, + }); + + let cookie = ApiCookie::build(COOKIE_UPSTREAM_CALLBACK, "", 0); + Ok((auth_step, cookie)) + } +} diff --git a/src/models/src/entity/auth_providers.rs b/src/models/src/entity/auth_providers.rs index 636296c0..0fbfc368 100644 --- a/src/models/src/entity/auth_providers.rs +++ b/src/models/src/entity/auth_providers.rs @@ -48,7 +48,6 @@ use std::borrow::Cow; use std::fmt::Write; use std::str::FromStr; use std::time::Duration; -use time::OffsetDateTime; use tracing::{debug, error}; use utoipa::ToSchema; @@ -729,7 +728,7 @@ impl AuthProviderCallback { Ok(()) } - async fn find(callback_id: String) -> Result { + pub async fn find(callback_id: String) -> Result { let opt: Option = DB::client() .get(Cache::AuthProviderCallback, callback_id) .await?; @@ -743,7 +742,7 @@ impl AuthProviderCallback { } } - async fn save(&self) -> Result<(), ErrorResponse> { + pub async fn save(&self) -> Result<(), ErrorResponse> { DB::client() .put( Cache::AuthProviderCallback, @@ -1205,14 +1204,11 @@ impl AuthProviderIdClaims<'_> { let _header = parts.next().ok_or_else(|| { ErrorResponse::new( ErrorResponseType::BadRequest, - "incorrect ID did not contain claims".to_string(), + "incorrect ID did not contain claims", ) })?; let claims = parts.next().ok_or_else(|| { - ErrorResponse::new( - ErrorResponseType::BadRequest, - "ID token was unsigned".to_string(), - ) + ErrorResponse::new(ErrorResponseType::BadRequest, "ID token was unsigned") })?; debug!("upstream ID token claims:\n{}", claims); let json_bytes = base64_url_no_pad_decode(claims)?; @@ -1227,10 +1223,7 @@ impl AuthProviderIdClaims<'_> { if self.email.is_none() { let err = "No `email` in ID token claims. This is a mandatory claim"; error!("{}", err); - return Err(ErrorResponse::new( - ErrorResponseType::BadRequest, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::BadRequest, err)); } let claims_user_id = if let Some(sub) = &self.sub { @@ -1242,10 +1235,7 @@ impl AuthProviderIdClaims<'_> { } else { let err = "Cannot find any user id in the response"; error!("{}", err); - return Err(ErrorResponse::new( - ErrorResponseType::BadRequest, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::BadRequest, err)); } // We need to create a real string here, since we don't know what json type we get. // Any json number would become a String too, which is what we need for compatibility. @@ -1274,7 +1264,7 @@ impl AuthProviderIdClaims<'_> { if link.provider_id != provider.id { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad provider_id in link cookie".to_string(), + "bad provider_id in link cookie", )); } @@ -1284,7 +1274,7 @@ impl AuthProviderIdClaims<'_> { // multiple accounts. return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "bad user_id in link cookie".to_string(), + "bad user_id in link cookie", )); } @@ -1292,7 +1282,7 @@ impl AuthProviderIdClaims<'_> { if link.user_email != user.email { return Err(ErrorResponse::new( ErrorResponseType::BadRequest, - "Invalid E-Mail".to_string(), + "Invalid E-Mail", )); } @@ -1321,7 +1311,7 @@ impl AuthProviderIdClaims<'_> { if provider.admin_claim_value.is_none() { return Err(ErrorResponse::new( ErrorResponseType::Internal, - "Misconfigured Auth Provider - admin claim path without value".to_string(), + "Misconfigured Auth Provider - admin claim path without value", )); } @@ -1362,7 +1352,7 @@ impl AuthProviderIdClaims<'_> { if provider.mfa_claim_value.is_none() { return Err(ErrorResponse::new( ErrorResponseType::Internal, - "Misconfigured Auth Provider - mfa claim path without value".to_string(), + "Misconfigured Auth Provider - mfa claim path without value", )); } @@ -1397,7 +1387,7 @@ impl AuthProviderIdClaims<'_> { } } - let now = OffsetDateTime::now_utc().unix_timestamp(); + let now = chrono::Utc::now().timestamp(); let user = if let Some(mut user) = user_opt { let mut old_email = None; let mut forbidden_error = None; @@ -1422,10 +1412,7 @@ impl AuthProviderIdClaims<'_> { Some(user.failed_login_attempts.unwrap_or_default() + 1); user.save(old_email).await?; - return Err(ErrorResponse::new( - ErrorResponseType::Forbidden, - err.to_string(), - )); + return Err(ErrorResponse::new(ErrorResponseType::Forbidden, err)); } // check / update email diff --git a/src/models/src/entity/mod.rs b/src/models/src/entity/mod.rs index 8e100e7c..33eea604 100644 --- a/src/models/src/entity/mod.rs +++ b/src/models/src/entity/mod.rs @@ -5,6 +5,7 @@ use sqlx::query; pub mod api_keys; pub mod app_version; +pub mod atproto; pub mod auth_codes; pub mod auth_providers; pub mod clients;