Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Add GraphQL mutation to do self-service user registration #3050

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ impl Options {
homeserver_connection.clone(),
site_config.clone(),
password_manager.clone(),
http_client_factory.clone(),
url_builder.clone(),
);

let state = {
Expand Down
6 changes: 5 additions & 1 deletion crates/handlers/src/captcha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::net::IpAddr;

use async_graphql::InputObject;
use axum::BoxError;
use hyper::Request;
use mas_axum_utils::http_client_factory::HttpClientFactory;
Expand Down Expand Up @@ -58,8 +59,11 @@ pub enum Error {
RequestFailed(#[source] BoxError),
}

/// Form (or GraphQL input) containing a CAPTCHA provider's response
/// for one of the providers.
#[allow(clippy::struct_field_names)]
#[derive(Debug, Deserialize, Default)]
#[derive(Debug, Deserialize, Default, InputObject)]
#[graphql(input_name = "CaptchaForm")]
#[serde(rename_all = "kebab-case")]
pub struct Form {
g_recaptcha_response: Option<String>,
Expand Down
40 changes: 33 additions & 7 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ use futures_util::TryStreamExt;
use headers::{authorization::Bearer, Authorization, ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{
cookies::CookieJar, sentry::SentryEventID, FancyError, SessionInfo, SessionInfoExt,
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, FancyError,
SessionInfo, SessionInfoExt,
};
use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_data_model::{BrowserSession, Session, SiteConfig, User, UserAgent};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
use mas_storage_pg::PgRepository;
use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
Expand All @@ -60,7 +62,9 @@ use self::{
mutations::Mutation,
query::Query,
};
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};
use crate::{
impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker, PreferredLanguage,
};

#[cfg(test)]
mod tests;
Expand All @@ -71,6 +75,8 @@ struct GraphQLState {
policy_factory: Arc<PolicyFactory>,
site_config: SiteConfig,
password_manager: PasswordManager,
http_client_factory: HttpClientFactory,
url_builder: UrlBuilder,
}

#[async_trait]
Expand Down Expand Up @@ -111,6 +117,14 @@ impl state::State for GraphQLState {
let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
Box::new(rng)
}

fn http_client_factory(&self) -> &HttpClientFactory {
&self.http_client_factory
}

fn url_builder(&self) -> &UrlBuilder {
&self.url_builder
}
}

#[must_use]
Expand All @@ -120,13 +134,17 @@ pub fn schema(
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
site_config: SiteConfig,
password_manager: PasswordManager,
http_client_factory: HttpClientFactory,
url_builder: UrlBuilder,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection),
site_config,
password_manager,
http_client_factory,
url_builder,
};
let state: BoxState = Box::new(state);

Expand Down Expand Up @@ -281,12 +299,14 @@ async fn get_requester(

pub async fn post(
AxumState(schema): AxumState<Schema>,
PreferredLanguage(locale): PreferredLanguage,
clock: BoxClock,
repo: BoxRepository,
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
content_type: Option<TypedHeader<ContentType>>,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
body: Body,
) -> Result<impl IntoResponse, RouteError> {
let body = body.into_data_stream();
Expand All @@ -304,8 +324,11 @@ pub async fn post(
.into_async_read(),
MultipartOptions::default(),
)
.await?
.data(requester); // XXX: this should probably return another error response?
.await? // XXX: this should probably return another error response?
.data(requester)
.data(user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())))
.data(locale)
.data(activity_tracker);

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand All @@ -328,6 +351,7 @@ pub async fn get(
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let token = authorization
Expand All @@ -336,8 +360,10 @@ pub async fn get(
let (session_info, _cookie_jar) = cookie_jar.session_info();
let requester = get_requester(&clock, &activity_tracker, repo, session_info, token).await?;

let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
.data(requester)
.data(activity_tracker)
.data(user_agent);

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand Down
Loading
Loading