Skip to content

Commit

Permalink
proxy: add jwks endpoint to control plane and mock providers (#9165)
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate authored Sep 27, 2024
1 parent 42ef08d commit 43b2445
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 16 deletions.
16 changes: 16 additions & 0 deletions proxy/src/auth/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}

#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
}

impl std::fmt::Display for Backend<'_, (), ()> {
Expand Down Expand Up @@ -585,6 +593,14 @@ mod tests {
))
}

async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}

async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
Expand Down
38 changes: 36 additions & 2 deletions proxy/src/console/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
Expand All @@ -16,7 +19,7 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
Expand Down Expand Up @@ -334,6 +337,12 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;

async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;

/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
Expand All @@ -343,6 +352,7 @@ pub(crate) trait Api {
}

#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
Expand Down Expand Up @@ -386,6 +396,20 @@ impl Api for ConsoleBackend {
}
}

async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}

async fn wake_compute(
&self,
ctx: &RequestMonitoring,
Expand Down Expand Up @@ -552,3 +576,13 @@ impl WakeComputePermit {
res
}
}

impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}
45 changes: 44 additions & 1 deletion proxy/src/console/provider/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::context::RequestMonitoring;
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
Expand Down Expand Up @@ -118,6 +120,39 @@ impl Api {
})
}

async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;

let connection = tokio::spawn(connection);

let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;

let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}

drop(client);
connection.await??;

Ok(rows)
}

async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
Expand Down Expand Up @@ -185,6 +220,14 @@ impl super::Api for Api {
))
}

async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}

#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
Expand Down
90 changes: 81 additions & 9 deletions proxy/src/console/provider/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,33 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::ComputeUserInfo,
auth::backend::{jwt::AuthRule, ComputeUserInfo},
compute,
console::messages::{ColdStartInfo, Reason},
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey,
scram, EndpointCacheKey, EndpointId,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};

const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");

#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
jwt: String,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
}

impl Api {
Expand All @@ -38,7 +44,9 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
Self {
endpoint,
caches,
Expand Down Expand Up @@ -71,9 +79,9 @@ impl Api {
async {
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
Expand Down Expand Up @@ -125,6 +133,61 @@ impl Api {
.await
}

async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;

info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");

let body = parse_body::<EndpointJwksResponse>(response).await?;

let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();

Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}

async fn do_wake_compute(
&self,
ctx: &RequestMonitoring,
Expand All @@ -135,7 +198,7 @@ impl Api {
async {
let mut request_builder = self
.endpoint
.get("proxy_wake_compute")
.get_path("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
Expand Down Expand Up @@ -262,6 +325,15 @@ impl super::Api for Api {
))
}

#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}

#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
Expand Down
16 changes: 12 additions & 4 deletions proxy/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,17 @@ impl Endpoint {

/// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}

/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex paths queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
let mut url = self.endpoint.clone();
url.path_segments_mut().push(path);
f(&mut url);
self.client.get(url.into_inner())
}

Expand Down Expand Up @@ -144,7 +152,7 @@ mod tests {

// Validate that this pattern makes sense.
let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[
("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all
Expand All @@ -162,7 +170,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new());

let req = endpoint
.get("frobnicate")
.get_path("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())])
.build()?;

Expand Down
4 changes: 4 additions & 0 deletions proxy/src/proxy/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,10 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}

fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
}

fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
Expand Down

1 comment on commit 43b2445

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5013 tests run: 4854 passed, 1 failed, 158 skipped (full report)


Failures on Postgres 17

# Run all failed tests locally:
scripts/pytest -vv -n $(nproc) -k "test_storage_controller_heartbeats[release-pg17-failure4]"
Flaky tests (10)

Postgres 17

Postgres 16

  • test_ondemand_wal_download_in_replication_slot_funcs: release-x86-64

Postgres 15

Postgres 14

Test coverage report is not available

The comment gets automatically updated with the latest test results
43b2445 at 2024-09-27T15:58:24.213Z :recycle:

Please sign in to comment.