Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ATProto support #644

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,10 @@ webauthn-rs = { version = "0.5", features = [
"danger-allow-state-serialisation", "danger-credential-internals"
] }
webauthn-rs-proto = "0.5"

atrium-api = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" }
atrium-common = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" }
atrium-identity = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" }
atrium-oauth-client = { git = "https://github.com/avdb13/atrium", branch = "oauth-session" }

hickory-resolver = { version = "0.24.1", features = ["tokio", "tokio-rustls", "rustls"] }
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
token_endpoint: yup.string().url(),
userinfo_endpoint: yup.string().url(),

name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128"),
name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved"),
client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128"),
client_secret: yup.string().trim().max(256, "Max 256 characters"),
scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
token_endpoint: yup.string().url().required('Required'),
userinfo_endpoint: yup.string().url().required('Required'),

name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").required('Required'),
name: yup.string().trim().matches(REGEX_CLIENT_NAME, "Can only contain: 'a-zA-Z0-9À-ÿ- ', length max: 128").not(["atproto"], "Cannot be reserved").required('Required'),
client_id: yup.string().trim().matches(REGEX_URI, "Can only contain URI safe characters, length max: 128").required('Required'),
client_secret: yup.string().trim().max(256, "Max 256 characters"),
scope: yup.string().trim().matches(REGEX_PROVIDER_SCOPE, "Can only contain: 'a-zA-Z0-9-_/ ', length max: 128").required('Required'),
Expand Down
8 changes: 8 additions & 0 deletions rauthy.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ SUSPICIOUS_REQUESTS_BLACKLIST=1440
# default: false
SUSPICIOUS_REQUESTS_LOG=true

#####################################
############## ATPROTO ##############
#####################################

# Set to `true` to enable the atproto client.
# default: false
# ATPROTO_ENABLE=false

#####################################
############# BACKUPS ###############
#####################################
Expand Down
122 changes: 122 additions & 0 deletions src/api/src/atproto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use actix_web::{
get,
http::header::{self, HeaderValue, LOCATION},
post,
web::{self, Json, Query},
HttpRequest, HttpResponse,
};
use rauthy_api_types::{atproto, auth_providers::ProviderLookupResponse};
use rauthy_common::constants::ATPROTO_ENABLE;
use rauthy_error::{ErrorResponse, ErrorResponseType};
use rauthy_models::{
app_state::AppState,
entity::{atproto::AtprotoCallback, auth_providers::AuthProviderCallback},
};

use crate::{map_auth_step, ReqPrincipal};

#[utoipa::path(
get,
path = "/atproto/client_metadata",
tag = "atproto",
responses(
(status = 200, description = "OK"),
),
)]
#[get("/atproto/client_metadata")]
pub async fn get_client_metadata(data: web::Data<AppState>) -> Result<HttpResponse, ErrorResponse> {
is_atproto_enabled()?;

Ok(HttpResponse::Ok()
.insert_header((
header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_str("*").unwrap(),
))
.json(&data.atproto.client_metadata))
}

#[utoipa::path(
post,
path = "/atproto/login",
tag = "atproto",
responses(
(status = 202, description = "Accepted"),
(status = 400, description = "BadRequest", body = ErrorResponse),
(status = 404, description = "NotFound", body = ErrorResponse),
),
)]
#[post("/atproto/login")]
pub async fn post_login(
data: web::Data<AppState>,
payload: Json<atproto::LoginRequest>,
principal: ReqPrincipal,
) -> Result<HttpResponse, ErrorResponse> {
is_atproto_enabled()?;
principal.validate_session_auth_or_init()?;

let payload = payload.into_inner();
let (cookie, xsrf_token, location) =
<AuthProviderCallback as AtprotoCallback>::login_start(&data, payload).await?;

Ok(HttpResponse::Accepted()
.insert_header((LOCATION, location))
.cookie(cookie)
.body(xsrf_token))
}

#[utoipa::path(
post,
path = "/atproto/callback",
tag = "atproto",
responses(
(status = 200, description = "OK", body = ProviderLookupResponse),
(status = 400, description = "BadRequest", body = ErrorResponse),
(status = 404, description = "NotFound", body = ErrorResponse),
),
)]
#[post("/atproto/callback")]
#[tracing::instrument(
name = "post_provider_callback",
skip_all, fields(callback_id = payload.state)
)]
pub async fn post_callback(
data: web::Data<AppState>,
req: HttpRequest,
payload: Query<atproto::CallbackRequest>,
principal: ReqPrincipal,
) -> Result<HttpResponse, ErrorResponse> {
is_atproto_enabled()?;
principal.validate_session_auth_or_init()?;

let payload = payload.into_inner();
let session = principal.get_session()?;
let (auth_step, cookie) = <AuthProviderCallback as AtprotoCallback>::login_finish(
&data,
&req,
&payload,
session.clone(),
)
.await?;

let mut resp = map_auth_step(auth_step, &req).await?;
resp.add_cookie(&cookie).map_err(|err| {
ErrorResponse::new(
ErrorResponseType::Internal,
format!("Error adding cookie after map_auth_step: {}", err),
)
})?;

Ok(resp)
}

#[inline(always)]
fn is_atproto_enabled() -> Result<(), ErrorResponse> {
if *ATPROTO_ENABLE {
Ok(())
} else {
Err(ErrorResponse::new(
ErrorResponseType::Internal,
"The atproto client is disabled on this instance",
))
}
}
14 changes: 14 additions & 0 deletions src/api/src/auth_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ pub async fn post_provider(
) -> Result<HttpResponse, ErrorResponse> {
principal.validate_admin_session()?;

if payload.name == "atproto" {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"Must not contain a reserved name".to_string(),
));
}

if !payload.use_pkce && payload.client_secret.is_none() {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
Expand Down Expand Up @@ -261,6 +268,13 @@ pub async fn put_provider(
) -> Result<HttpResponse, ErrorResponse> {
principal.validate_admin_session()?;

if payload.name == "atproto" {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
"Must not contain a reserved name".to_string(),
));
}

if !payload.use_pkce && payload.client_secret.is_none() {
return Err(ErrorResponse::new(
ErrorResponseType::BadRequest,
Expand Down
1 change: 1 addition & 0 deletions src/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use rust_embed::RustEmbed;
use tracing::error;

pub mod api_keys;
pub mod atproto;
pub mod auth_providers;
pub mod blacklist;
pub mod clients;
Expand Down
15 changes: 13 additions & 2 deletions src/api/src/openapi.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::{
api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles,
scopes, sessions, users,
api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc,
roles, scopes, sessions, users,
};
use actix_web::web;
use rauthy_api_types::atproto::{
CallbackRequest as AtprotoCallbackRequest, LoginRequest as AtprotoLoginRequest,
};
use rauthy_api_types::{
api_keys::*, auth_providers::*, blacklist::*, clients::*, events::*, fed_cm::*, generic::*,
groups::*, oidc::*, roles::*, scopes::*, sessions::*, users::*,
Expand All @@ -26,6 +29,10 @@ use utoipa::{openapi, OpenApi};
api_keys::get_api_key_test,
api_keys::put_api_key_secret,

atproto::get_client_metadata,
atproto::post_login,
atproto::post_callback,

auth_providers::post_providers,
auth_providers::post_provider,
auth_providers::post_provider_lookup,
Expand Down Expand Up @@ -152,6 +159,7 @@ use utoipa::{openapi, OpenApi};
components(
schemas(
entity::colors::Colors,
entity::atproto::DnsTxtResolver,
entity::fed_cm::FedCMAccount,
entity::fed_cm::FedCMAccounts,
entity::fed_cm::FedCMIdPBranding,
Expand Down Expand Up @@ -186,6 +194,8 @@ use utoipa::{openapi, OpenApi};
ErrorResponseType,

ApiKeyRequest,
AtprotoCallbackRequest,
AtprotoLoginRequest,
AuthCodeRequest,
AuthRequest,
IpBlacklistRequest,
Expand Down Expand Up @@ -294,6 +304,7 @@ use utoipa::{openapi, OpenApi};
(name = "generic", description = "Generic endpoints"),
(name = "webid", description = "WebID endpoints"),
(name = "fed_cm", description = "Experimental FedCM endpoints"),
(name = "atproto", description = "Experimental ATProto endpoints"),
(name = "deprecated", description = "Deprecated endpoints - will be removed in a future version"),
),
)]
Expand Down
40 changes: 40 additions & 0 deletions src/api_types/src/atproto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use rauthy_common::constants::{RE_ALNUM, RE_URI};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;

#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
pub struct LoginRequest {
/// Validation:
/// `^(did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]|([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)$`
pub at_id: String,
/// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$`
#[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))]
pub redirect_uri: String,
/// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$`
#[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))]
pub state: Option<String>,

/// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$`
#[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))]
pub pkce_challenge: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
pub struct CallbackRequest {
/// Validation: `[a-zA-Z0-9]`
#[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))]
pub state: String,
/// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$`
#[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))]
pub code: String,
/// Validation: `[a-zA-Z0-9]`
#[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))]
pub iss: Option<String>,
/// Validation: `[a-zA-Z0-9]`
#[validate(regex(path = "*RE_ALNUM", code = "[a-zA-Z0-9]"))]
pub xsrf_token: String,
/// Validation: `[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$`
#[validate(regex(path = "*RE_URI", code = "[a-zA-Z0-9,.:/_-&?=~#!$'()*+%]+$"))]
pub pkce_verifier: String,
}
1 change: 1 addition & 0 deletions src/api_types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod api_keys;
pub mod atproto;
pub mod auth_providers;
pub mod blacklist;
pub mod clients;
Expand Down
9 changes: 6 additions & 3 deletions src/bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use rauthy_common::utils::UseDummyAddress;
use rauthy_common::{is_hiqlite, is_sqlite, password_hasher};
use rauthy_handlers::openapi::ApiDoc;
use rauthy_handlers::{
api_keys, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc, roles,
scopes, sessions, users,
api_keys, atproto, auth_providers, blacklist, clients, events, fed_cm, generic, groups, oidc,
roles, scopes, sessions, users,
};
use rauthy_middlewares::csrf_protection::CsrfProtectionMiddleware;
use rauthy_middlewares::ip_blacklist::RauthyIpBlacklistMiddleware;
Expand Down Expand Up @@ -556,7 +556,10 @@ async fn actix_main(app_state: web::Data<AppState>) -> std::io::Result<()> {
.service(oidc::get_well_known)
.service(generic::get_health)
.service(generic::get_ready)
.service(generic::get_static_assets),
.service(generic::get_static_assets)
.service(atproto::get_client_metadata)
.service(atproto::post_login)
.service(atproto::post_callback),
),
);

Expand Down
5 changes: 5 additions & 0 deletions src/common/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ lazy_static! {
.parse::<bool>()
.expect("EXPERIMENTAL_FED_CM_ENABLE cannot be parsed to bool - bad format");

pub static ref ATPROTO_ENABLE: bool = env::var("ATPROTO_ENABLE")
.unwrap_or_else(|_| String::from("false"))
.parse::<bool>()
.expect("ATPROTO_ENABLE cannot be parsed to bool - bad format");

pub static ref REFRESH_TOKEN_LIFETIME: u16 = env::var("REFRESH_TOKEN_LIFETIME")
.unwrap_or_else(|_| String::from("48"))
.parse::<u16>()
Expand Down
7 changes: 7 additions & 0 deletions src/models/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ validator = { workspace = true }
webauthn-rs = { workspace = true }
webauthn-rs-proto = { workspace = true }

atrium-api = { workspace = true }
atrium-common = { workspace = true }
atrium-identity = { workspace = true }
atrium-oauth-client = { workspace = true }

hickory-resolver = { workspace = true }

[dev-dependencies]
pretty_assertions = "1"
rstest = "0.18.2"
Expand Down
6 changes: 6 additions & 0 deletions src/models/src/app_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::email::EMail;
use crate::entity::atproto::AtprotoClient;
use crate::events::event::Event;
use crate::events::ip_blacklist_handler::IpBlacklistReq;
use crate::events::listener::EventRouterMsg;
Expand Down Expand Up @@ -31,6 +32,7 @@ pub struct AppState {
pub tx_events_router: flume::Sender<EventRouterMsg>,
pub tx_ip_blacklist: flume::Sender<IpBlacklistReq>,
pub webauthn: Arc<Webauthn>,
pub atproto: AtprotoClient,
}

impl AppState {
Expand Down Expand Up @@ -154,6 +156,9 @@ impl AppState {
.rp_name(&rp_name);
let webauthn = Arc::new(builder.build().expect("Invalid configuration"));

let atproto = AtprotoClient::new(&listen_scheme, &public_url)
.expect("failed to initialize atproto client");

Ok(Self {
public_url,
argon2_params,
Expand All @@ -170,6 +175,7 @@ impl AppState {
tx_events_router,
tx_ip_blacklist,
webauthn,
atproto,
})
}

Expand Down
Loading