Skip to content

feat(auth): Implement issuer allow-list and required-auth flag #2847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions crates/client-api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::time::{Duration, SystemTime};

use axum::extract::{Query, Request, State};
Expand Down Expand Up @@ -197,8 +198,12 @@ pub struct JwtKeyAuthProvider<TV: TokenValidator + Send + Sync> {
pub type DefaultJwtAuthProvider = JwtKeyAuthProvider<DefaultValidator>;

// Create a new AuthEnvironment using the default caching validator.
pub fn default_auth_environment(keys: JwtKeys, local_issuer: String) -> JwtKeyAuthProvider<DefaultValidator> {
let validator = new_validator(keys.public.clone(), local_issuer.clone());
pub fn default_auth_environment(
keys: JwtKeys,
local_issuer: String,
allowed_oidc_issuers: Option<HashSet<String>>,
) -> JwtKeyAuthProvider<DefaultValidator> {
let validator = new_validator(keys.public.clone(), local_issuer.clone(), allowed_oidc_issuers);
JwtKeyAuthProvider::new(keys, local_issuer, validator)
}

Expand Down Expand Up @@ -269,23 +274,32 @@ impl<S: NodeDelegate + Send + Sync> axum::extract::FromRequestParts<S> for Space
type Rejection = AuthorizationRejection;
async fn from_request_parts(parts: &mut request::Parts, state: &S) -> Result<Self, Self::Rejection> {
let Some(creds) = SpacetimeCreds::from_request_parts(parts)? else {
// No token was provided at all. This is a legitimate anonymous user.
return Ok(Self { auth: None });
};

let claims = state
.jwt_auth_provider()
.validator()
.validate_token(&creds.token)
.await
.map_err(AuthorizationRejection::Custom)?;

let auth = SpacetimeAuth {
creds,
identity: claims.identity,
subject: claims.subject,
issuer: claims.issuer,
};
Ok(Self { auth: Some(auth) })
// Credentials ARE present. They MUST now be validated successfully.
let validation_result = state.jwt_auth_provider().validator().validate_token(&creds.token).await;

match validation_result {
Ok(claims) => {
// The token is valid. Create the auth struct.
let auth = SpacetimeAuth {
creds,
identity: claims.identity,
subject: claims.subject,
issuer: claims.issuer,
};
Ok(Self { auth: Some(auth) })
}
Err(validation_error) => {
// The token was present but INVALID.
// This is a hard failure. We must reject the request.
// We are explicitly returning an Err that will halt the request
// with a 401 Unauthorized status.
Err(AuthorizationRejection::Custom(validation_error))
}
}
}
}

Expand Down Expand Up @@ -423,7 +437,17 @@ pub async fn anon_auth_middleware<S: ControlStateDelegate + NodeDelegate>(
mut req: Request,
next: Next,
) -> axum::response::Result<impl IntoResponse> {
let auth = auth.get_or_create(&worker_ctx).await?;
let auth = match auth.get() {
None => {
if worker_ctx.auth_required() {
log::warn!("Rejecting anonymous connection because --auth-required is set.");
Err(AuthorizationRejection::Required.into())
} else {
SpacetimeAuth::alloc(&worker_ctx).await
}
}
Some(authorization) => Ok(authorization),
}?;
req.extensions_mut().insert(auth.clone());
let resp = next.run(req).await;
Ok((auth.into_headers(), resp))
Expand Down
5 changes: 4 additions & 1 deletion crates/client-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub trait NodeDelegate: Send + Sync {

type JwtAuthProviderT: auth::JwtAuthProvider;
fn jwt_auth_provider(&self) -> &Self::JwtAuthProviderT;
fn auth_required(&self) -> bool;
/// Return the leader [`Host`] of `database_id`.
///
/// Returns `None` if the current leader is not hosted by this node.
Expand Down Expand Up @@ -363,7 +364,9 @@ impl<T: NodeDelegate + ?Sized> NodeDelegate for Arc<T> {
fn jwt_auth_provider(&self) -> &Self::JwtAuthProviderT {
(**self).jwt_auth_provider()
}

fn auth_required(&self) -> bool {
(**self).auth_required()
}
async fn leader(&self, database_id: u64) -> anyhow::Result<Option<Host>> {
(**self).leader(database_id).await
}
Expand Down
61 changes: 58 additions & 3 deletions crates/core/src/auth/token_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub use jsonwebtoken::{DecodingKey, EncodingKey};
use jwks::Jwks;
use lazy_static::lazy_static;
use serde::Serialize;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use thiserror;
Expand Down Expand Up @@ -118,13 +119,34 @@ where
}
}

pub type DefaultValidator = FullTokenValidator<CachingOidcTokenValidator>;
pub type DefaultValidator = FullTokenValidator<AllowedIssuersValidator<CachingOidcTokenValidator>>;

pub fn new_validator(
local_key: DecodingKey,
local_issuer: String,
// Accept an optional list of allowed OIDC issuers.
allowed_oidc_issuers: Option<HashSet<String>>,
) -> DefaultValidator {
// If the allow-list is empty or not provided, log a prominent security warning.
if allowed_oidc_issuers.is_none() || allowed_oidc_issuers.as_ref().unwrap().is_empty() {
log::warn!(
"SECURITY WARNING: No OIDC issuer allow-list is configured. \
Any valid OIDC token from ANY issuer will be accepted. \
This is NOT recommended for production environments. \
Please configure 'allowed_oidc_issuers'."
);
}

let caching_validator = CachingOidcTokenValidator::get_default();
let oidc_validator = AllowedIssuersValidator {
allowed_issuers: allowed_oidc_issuers,
inner_validator: caching_validator,
};

pub fn new_validator(local_key: DecodingKey, local_issuer: String) -> FullTokenValidator<CachingOidcTokenValidator> {
FullTokenValidator {
local_key,
local_issuer,
oidc_validator: CachingOidcTokenValidator::get_default(),
oidc_validator,
}
}

Expand Down Expand Up @@ -197,6 +219,39 @@ impl CachingOidcTokenValidator {
}
}

/// A validator that wraps another validator and adds an issuer allow-list.
/// If `allowed_issuers` is `None`, it will permit any issuer (with a warning).
pub struct AllowedIssuersValidator<T: TokenValidator> {
// We use an Option to clearly distinguish between "enforce this list"
// and "allow any issuer".
allowed_issuers: Option<HashSet<String>>,
inner_validator: T,
}

#[async_trait]
impl<T: TokenValidator + Send + Sync> TokenValidator for AllowedIssuersValidator<T> {
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
if let Some(allowed) = &self.allowed_issuers {
// Get the issuer without full validation first.
let issuer = get_raw_issuer(token)?;

// Check if the issuer is in our allow-list.
if !allowed.contains(&issuer) {
log::warn!(
"Token validation failed: issuer '{}' is not in the allowed list.",
issuer
);
return Err(TokenValidationError::Other(anyhow::anyhow!(
"Issuer is not in the allowed list"
)));
}
}

// If the check passes (or if no list is configured), proceed with the inner validator.
self.inner_validator.validate_token(token).await
}
}

// Jwks fetcher for the async cache.
struct KeyFetcher;

Expand Down
23 changes: 18 additions & 5 deletions crates/standalone/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use spacetimedb_client_api_messages::name::{DomainName, InsertDomainResult, Regi
use spacetimedb_paths::server::{ModuleLogsDir, PidFile, ServerDataDir};
use spacetimedb_paths::standalone::StandaloneDataDirExt;
use spacetimedb_table::page_pool::PagePool;
use std::collections::HashSet;
use std::sync::Arc;

pub use spacetimedb_client_api::routes::subscribe::{BIN_PROTOCOL, TEXT_PROTOCOL};
Expand All @@ -40,6 +41,7 @@ pub struct StandaloneEnv {
metrics_registry: prometheus::Registry,
_pid_file: PidFile,
auth_provider: auth::DefaultJwtAuthProvider,
pub auth_required: bool,
}

impl StandaloneEnv {
Expand All @@ -48,6 +50,8 @@ impl StandaloneEnv {
certs: &CertificateAuthority,
data_dir: Arc<ServerDataDir>,
db_cores: JobCores,
allowed_oidc_issuers: Option<HashSet<String>>,
auth_required: bool,
) -> anyhow::Result<Arc<Self>> {
let _pid_file = data_dir.pid_file()?;
let meta_path = data_dir.metadata_toml();
Expand Down Expand Up @@ -75,7 +79,7 @@ impl StandaloneEnv {
let client_actor_index = ClientActorIndex::new();
let jwt_keys = certs.get_or_create_keys()?;

let auth_env = auth::default_auth_environment(jwt_keys, LOCALHOST.to_owned());
let auth_env = auth::default_auth_environment(jwt_keys, LOCALHOST.to_owned(), allowed_oidc_issuers);

let metrics_registry = prometheus::Registry::new();
metrics_registry.register(Box::new(&*WORKER_METRICS)).unwrap();
Expand All @@ -90,6 +94,7 @@ impl StandaloneEnv {
metrics_registry,
_pid_file,
auth_provider: auth_env,
auth_required,
}))
}

Expand Down Expand Up @@ -140,6 +145,10 @@ impl NodeDelegate for StandaloneEnv {
&self.auth_provider
}

fn auth_required(&self) -> bool {
self.auth_required
}

async fn leader(&self, database_id: u64) -> anyhow::Result<Option<Host>> {
let leader = match self.control_db.get_leader_replica_by_database(database_id) {
Some(leader) => leader,
Expand Down Expand Up @@ -519,11 +528,15 @@ mod tests {
page_pool_max_size: None,
};

let _env = StandaloneEnv::init(config, &ca, data_dir.clone(), Default::default()).await?;
// We pass `None` here to test the default behavior (allow all).
// TODO: Test with list of allowed OIDC issuers
let _env = StandaloneEnv::init(config, &ca, data_dir.clone(), Default::default(), None, false).await?;
// Ensure that we have a lock.
assert!(StandaloneEnv::init(config, &ca, data_dir.clone(), Default::default())
.await
.is_err());
assert!(
StandaloneEnv::init(config, &ca, data_dir.clone(), Default::default(), None, false)
.await
.is_err()
);

Ok(())
}
Expand Down
31 changes: 29 additions & 2 deletions crates/standalone/src/subcommands/start.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::collections::HashSet;
use std::sync::Arc;

use crate::StandaloneEnv;
use anyhow::Context;
use axum::extract::DefaultBodyLimit;
use clap::ArgAction::SetTrue;
use clap::{Arg, ArgMatches};
use clap::{Arg, ArgAction, ArgMatches};
use spacetimedb::config::{CertificateAuthority, ConfigFile};
use spacetimedb::db::{Config, Storage};
use spacetimedb::startup::{self, TracingOptions};
Expand Down Expand Up @@ -73,6 +74,20 @@ pub fn cli() -> clap::Command {
"The maximum size of the page pool in bytes. Should be a multiple of 64KiB. The default is 8GiB.",
),
)
.arg(
Arg::new("allowed-oidc-issuer")
.long("allowed-oidc-issuer")
.help("Allow tokens from this OIDC issuer. Can be specified multiple times. If not set, all issuers are allowed (NOT RECOMMENDED FOR PRODUCTION).")
// This allows the user to provide the flag multiple times,
// e.g., --allowed-oidc-issuer https://a.com --allowed-oidc-issuer https://b.com
.action(ArgAction::Append)
)
.arg(
Arg::new("auth-required")
.long("auth-required")
.action(SetTrue)
.help("If specified, anonymous user creation is disabled and all connections must provide a valid JWT.")
)
// .after_help("Run `spacetime help start` for more detailed information.")
}

Expand Down Expand Up @@ -104,6 +119,10 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> {
storage,
page_pool_max_size,
};
let allowed_oidc_issuers: Option<HashSet<String>> = args
.get_many::<String>("allowed-oidc-issuer")
.map(|vals| vals.cloned().collect());
let auth_required = args.get_flag("auth-required");

banner();
let exe_name = std::env::current_exe()?;
Expand Down Expand Up @@ -144,7 +163,15 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> {
.context("cannot omit --jwt-{pub,priv}-key-path when those options are not specified in config.toml")?;

let data_dir = Arc::new(data_dir.clone());
let ctx = StandaloneEnv::init(db_config, &certs, data_dir, db_cores).await?;
let ctx = StandaloneEnv::init(
db_config,
&certs,
data_dir,
db_cores,
allowed_oidc_issuers,
auth_required,
)
.await?;
worker_metrics::spawn_jemalloc_stats(listen_addr.clone());
worker_metrics::spawn_tokio_stats(listen_addr.clone());
worker_metrics::spawn_page_pool_stats(listen_addr.clone(), ctx.page_pool().clone());
Expand Down
14 changes: 10 additions & 4 deletions crates/testing/src/modules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,16 @@ impl CompiledModule {
};

let certs = CertificateAuthority::in_cli_config_dir(&paths.cli_config_dir);
let env =
spacetimedb_standalone::StandaloneEnv::init(config, &certs, paths.data_dir.into(), Default::default())
.await
.unwrap();
let env = spacetimedb_standalone::StandaloneEnv::init(
config,
&certs,
paths.data_dir.into(),
Default::default(),
None,
false,
)
.await
.unwrap();
// TODO: Fix this when we update identity generation.
let identity = Identity::ZERO;
let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().identity;
Expand Down
Loading