From 22d8834474d1f619b6ed351fd80033b4a064bb21 Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Thu, 17 Oct 2024 13:38:24 +0300 Subject: [PATCH] proxy: move the connection pools to separate file (#9398) First PR for #9284 Start unification of the client and connection pool interfaces: - Exclude the 'global_connections_count' out from the get_conn_entry() - Move remote connection pools to the conn_pool_lib as a reference - Unify clients among all the conn pools --- proxy/src/serverless/backend.rs | 13 +- proxy/src/serverless/conn_pool.rs | 585 ++---------------------- proxy/src/serverless/conn_pool_lib.rs | 562 +++++++++++++++++++++++ proxy/src/serverless/http_conn_pool.rs | 50 +- proxy/src/serverless/local_conn_pool.rs | 109 ++--- proxy/src/serverless/mod.rs | 5 +- proxy/src/serverless/sql_over_http.rs | 15 +- 7 files changed, 704 insertions(+), 635 deletions(-) create mode 100644 proxy/src/serverless/conn_pool_lib.rs diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index a180c4c2ed09..82e81dbcfef6 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -11,8 +11,9 @@ use tokio::net::{lookup_host, TcpStream}; use tracing::field::display; use tracing::{debug, info}; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; -use super::http_conn_pool::{self, poll_http2_client}; +use super::conn_pool::poll_client; +use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool}; +use super::http_conn_pool::{self, poll_http2_client, Send}; use super::local_conn_pool::{self, LocalClient, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; @@ -31,7 +32,7 @@ use crate::rate_limiter::EndpointRateLimiter; use crate::{compute, EndpointId, Host}; pub(crate) struct PoolingBackend { - pub(crate) http_conn_pool: Arc, + pub(crate) http_conn_pool: Arc>, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, @@ -199,7 +200,7 @@ impl PoolingBackend { &self, ctx: &RequestMonitoring, conn_info: ConnInfo, - ) -> Result { + ) -> Result, HttpConnError> { info!("pool: looking for an existing connection"); if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -481,7 +482,7 @@ impl ConnectMechanism for TokioMechanism { } struct HyperMechanism { - pool: Arc, + pool: Arc>, conn_info: ConnInfo, conn_id: uuid::Uuid, @@ -491,7 +492,7 @@ struct HyperMechanism { #[async_trait] impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; + type Connection = http_conn_pool::Client; type ConnectError = HttpConnError; type Error = HttpConnError; diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index aa869ff1c0a5..b97c6565101e 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,31 +1,29 @@ -use std::collections::HashMap; use std::fmt; -use std::ops::Deref; use std::pin::pin; -use std::sync::atomic::{self, AtomicUsize}; use std::sync::{Arc, Weak}; use std::task::{ready, Poll}; -use std::time::Duration; -use dashmap::DashMap; use futures::future::poll_fn; use futures::Future; -use parking_lot::RwLock; -use rand::Rng; use smallvec::SmallVec; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use tokio_postgres::{AsyncMessage, Socket}; use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, info_span, warn, Instrument, Span}; +use tracing::{error, info, info_span, warn, Instrument}; -use super::backend::HttpConnError; -use crate::auth::backend::ComputeUserInfo; use crate::context::RequestMonitoring; -use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; -use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; -use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -use crate::{DbName, EndpointCacheKey, RoleName}; +use crate::control_plane::messages::MetricsAuxInfo; +use crate::metrics::Metrics; + +use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool}; + +#[cfg(test)] +use { + super::conn_pool_lib::GlobalConnPoolOptions, + crate::auth::backend::ComputeUserInfo, + std::{sync::atomic, time::Duration}, +}; #[derive(Debug, Clone)] pub(crate) struct ConnInfoWithAuth { @@ -33,34 +31,12 @@ pub(crate) struct ConnInfoWithAuth { pub(crate) auth: AuthData, } -#[derive(Debug, Clone)] -pub(crate) struct ConnInfo { - pub(crate) user_info: ComputeUserInfo, - pub(crate) dbname: DbName, -} - #[derive(Debug, Clone)] pub(crate) enum AuthData { Password(SmallVec<[u8; 16]>), Jwt(String), } -impl ConnInfo { - // hm, change to hasher to avoid cloning? - pub(crate) fn db_and_user(&self) -> (DbName, RoleName) { - (self.dbname.clone(), self.user_info.user.clone()) - } - - pub(crate) fn endpoint_cache_key(&self) -> Option { - // We don't want to cache http connections for ephemeral endpoints. - if self.user_info.options.is_ephemeral() { - None - } else { - Some(self.user_info.endpoint_cache_key()) - } - } -} - impl fmt::Display for ConnInfo { // use custom display to avoid logging password fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -75,402 +51,6 @@ impl fmt::Display for ConnInfo { } } -struct ConnPoolEntry { - conn: ClientInner, - _last_access: std::time::Instant, -} - -// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool -// Number of open connections is limited by the `max_conns_per_endpoint`. -pub(crate) struct EndpointConnPool { - pools: HashMap<(DbName, RoleName), DbUserConnPool>, - total_conns: usize, - max_conns: usize, - _guard: HttpEndpointPoolsGuard<'static>, - global_connections_count: Arc, - global_pool_size_max_conns: usize, -} - -impl EndpointConnPool { - fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { - let Self { - pools, - total_conns, - global_connections_count, - .. - } = self; - pools.get_mut(&db_user).and_then(|pool_entries| { - pool_entries.get_conn_entry(total_conns, global_connections_count.clone()) - }) - } - - fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { - let Self { - pools, - total_conns, - global_connections_count, - .. - } = self; - if let Some(pool) = pools.get_mut(&db_user) { - let old_len = pool.conns.len(); - pool.conns.retain(|conn| conn.conn.conn_id != conn_id); - let new_len = pool.conns.len(); - let removed = old_len - new_len; - if removed > 0 { - global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(removed as i64); - } - *total_conns -= removed; - removed > 0 - } else { - false - } - } - - fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInner) { - let conn_id = client.conn_id; - - if client.is_closed() { - info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed"); - return; - } - let global_max_conn = pool.read().global_pool_size_max_conns; - if pool - .read() - .global_connections_count - .load(atomic::Ordering::Relaxed) - >= global_max_conn - { - info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full"); - return; - } - - // return connection to the pool - let mut returned = false; - let mut per_db_size = 0; - let total_conns = { - let mut pool = pool.write(); - - if pool.total_conns < pool.max_conns { - let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); - pool_entries.conns.push(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); - - returned = true; - per_db_size = pool_entries.conns.len(); - - pool.total_conns += 1; - pool.global_connections_count - .fetch_add(1, atomic::Ordering::Relaxed); - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .inc(); - } - - pool.total_conns - }; - - // do logging outside of the mutex - if returned { - info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); - } else { - info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); - } - } -} - -impl Drop for EndpointConnPool { - fn drop(&mut self) { - if self.total_conns > 0 { - self.global_connections_count - .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(self.total_conns as i64); - } - } -} - -pub(crate) struct DbUserConnPool { - conns: Vec>, -} - -impl Default for DbUserConnPool { - fn default() -> Self { - Self { conns: Vec::new() } - } -} - -impl DbUserConnPool { - fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { - let old_len = self.conns.len(); - - self.conns.retain(|conn| !conn.conn.is_closed()); - - let new_len = self.conns.len(); - let removed = old_len - new_len; - *conns -= removed; - removed - } - - fn get_conn_entry( - &mut self, - conns: &mut usize, - global_connections_count: Arc, - ) -> Option> { - let mut removed = self.clear_closed_clients(conns); - let conn = self.conns.pop(); - if conn.is_some() { - *conns -= 1; - removed += 1; - } - global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(removed as i64); - conn - } -} - -pub(crate) struct GlobalConnPool { - // endpoint -> per-endpoint connection pool - // - // That should be a fairly conteded map, so return reference to the per-endpoint - // pool as early as possible and release the lock. - global_pool: DashMap>>>, - - /// Number of endpoint-connection pools - /// - /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. - /// That seems like far too much effort, so we're using a relaxed increment counter instead. - /// It's only used for diagnostics. - global_pool_size: AtomicUsize, - - /// Total number of connections in the pool - global_connections_count: Arc, - - config: &'static crate::config::HttpConfig, -} - -#[derive(Debug, Clone, Copy)] -pub struct GlobalConnPoolOptions { - // Maximum number of connections per one endpoint. - // Can mix different (dbname, username) connections. - // When running out of free slots for a particular endpoint, - // falls back to opening a new connection for each request. - pub max_conns_per_endpoint: usize, - - pub gc_epoch: Duration, - - pub pool_shards: usize, - - pub idle_timeout: Duration, - - pub opt_in: bool, - - // Total number of connections in the pool. - pub max_total_conns: usize, -} - -impl GlobalConnPool { - pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { - let shards = config.pool_options.pool_shards; - Arc::new(Self { - global_pool: DashMap::with_shard_amount(shards), - global_pool_size: AtomicUsize::new(0), - config, - global_connections_count: Arc::new(AtomicUsize::new(0)), - }) - } - - #[cfg(test)] - pub(crate) fn get_global_connections_count(&self) -> usize { - self.global_connections_count - .load(atomic::Ordering::Relaxed) - } - - pub(crate) fn get_idle_timeout(&self) -> Duration { - self.config.pool_options.idle_timeout - } - - pub(crate) fn shutdown(&self) { - // drops all strong references to endpoint-pools - self.global_pool.clear(); - } - - pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { - let epoch = self.config.pool_options.gc_epoch; - let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); - loop { - interval.tick().await; - - let shard = rng.gen_range(0..self.global_pool.shards().len()); - self.gc(shard); - } - } - - fn gc(&self, shard: usize) { - debug!(shard, "pool: performing epoch reclamation"); - - // acquire a random shard lock - let mut shard = self.global_pool.shards()[shard].write(); - - let timer = Metrics::get() - .proxy - .http_pool_reclaimation_lag_seconds - .start_timer(); - let current_len = shard.len(); - let mut clients_removed = 0; - shard.retain(|endpoint, x| { - // if the current endpoint pool is unique (no other strong or weak references) - // then it is currently not in use by any connections. - if let Some(pool) = Arc::get_mut(x.get_mut()) { - let EndpointConnPool { - pools, total_conns, .. - } = pool.get_mut(); - - // ensure that closed clients are removed - for db_pool in pools.values_mut() { - clients_removed += db_pool.clear_closed_clients(total_conns); - } - - // we only remove this pool if it has no active connections - if *total_conns == 0 { - info!("pool: discarding pool for endpoint {endpoint}"); - return false; - } - } - - true - }); - - let new_len = shard.len(); - drop(shard); - timer.observe(); - - // Do logging outside of the lock. - if clients_removed > 0 { - let size = self - .global_connections_count - .fetch_sub(clients_removed, atomic::Ordering::Relaxed) - - clients_removed; - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .dec_by(clients_removed as i64); - info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); - } - let removed = current_len - new_len; - - if removed > 0 { - let global_pool_size = self - .global_pool_size - .fetch_sub(removed, atomic::Ordering::Relaxed) - - removed; - info!("pool: performed global pool gc. size now {global_pool_size}"); - } - } - - pub(crate) fn get( - self: &Arc, - ctx: &RequestMonitoring, - conn_info: &ConnInfo, - ) -> Result>, HttpConnError> { - let mut client: Option> = None; - let Some(endpoint) = conn_info.endpoint_cache_key() else { - return Ok(None); - }; - - let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); - if let Some(entry) = endpoint_pool - .write() - .get_conn_entry(conn_info.db_and_user()) - { - client = Some(entry.conn); - } - let endpoint_pool = Arc::downgrade(&endpoint_pool); - - // ok return cached connection if found and establish a new one otherwise - if let Some(client) = client { - if client.is_closed() { - info!("pool: cached connection '{conn_info}' is closed, opening a new one"); - return Ok(None); - } - tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); - tracing::Span::current().record( - "pid", - tracing::field::display(client.inner.get_process_id()), - ); - info!( - cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), - "pool: reusing connection '{conn_info}'" - ); - client.session.send(ctx.session_id())?; - ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); - ctx.success(); - return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); - } - Ok(None) - } - - fn get_or_create_endpoint_pool( - self: &Arc, - endpoint: &EndpointCacheKey, - ) -> Arc>> { - // fast path - if let Some(pool) = self.global_pool.get(endpoint) { - return pool.clone(); - } - - // slow path - let new_pool = Arc::new(RwLock::new(EndpointConnPool { - pools: HashMap::new(), - total_conns: 0, - max_conns: self.config.pool_options.max_conns_per_endpoint, - _guard: Metrics::get().proxy.http_endpoint_pools.guard(), - global_connections_count: self.global_connections_count.clone(), - global_pool_size_max_conns: self.config.pool_options.max_total_conns, - })); - - // find or create a pool for this endpoint - let mut created = false; - let pool = self - .global_pool - .entry(endpoint.clone()) - .or_insert_with(|| { - created = true; - new_pool - }) - .clone(); - - // log new global pool size - if created { - let global_pool_size = self - .global_pool_size - .fetch_add(1, atomic::Ordering::Relaxed) - + 1; - info!( - "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" - ); - } - - pool - } -} - pub(crate) fn poll_client( global_pool: Arc>, ctx: &RequestMonitoring, @@ -574,7 +154,7 @@ pub(crate) fn poll_client( } .instrument(span)); - let inner = ClientInner { + let inner = ClientInnerRemote { inner: client, session: tx, cancel, @@ -584,7 +164,7 @@ pub(crate) fn poll_client( Client::new(inner, conn_info, pool_clone) } -struct ClientInner { +pub(crate) struct ClientInnerRemote { inner: C, session: tokio::sync::watch::Sender, cancel: CancellationToken, @@ -592,131 +172,36 @@ struct ClientInner { conn_id: uuid::Uuid, } -impl Drop for ClientInner { - fn drop(&mut self) { - // on client drop, tell the conn to shut down - self.cancel.cancel(); +impl ClientInnerRemote { + pub(crate) fn inner_mut(&mut self) -> &mut C { + &mut self.inner } -} - -pub(crate) trait ClientInnerExt: Sync + Send + 'static { - fn is_closed(&self) -> bool; - fn get_process_id(&self) -> i32; -} -impl ClientInnerExt for tokio_postgres::Client { - fn is_closed(&self) -> bool { - self.is_closed() + pub(crate) fn inner(&self) -> &C { + &self.inner } - fn get_process_id(&self) -> i32 { - self.get_process_id() - } -} -impl ClientInner { - pub(crate) fn is_closed(&self) -> bool { - self.inner.is_closed() + pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender { + &mut self.session } -} -impl Client { - pub(crate) fn metrics(&self) -> Arc { - let aux = &self.inner.as_ref().unwrap().aux; - USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id, - branch_id: aux.branch_id, - }) + pub(crate) fn aux(&self) -> &MetricsAuxInfo { + &self.aux } -} - -pub(crate) struct Client { - span: Span, - inner: Option>, - conn_info: ConnInfo, - pool: Weak>>, -} -pub(crate) struct Discard<'a, C: ClientInnerExt> { - conn_info: &'a ConnInfo, - pool: &'a mut Weak>>, -} - -impl Client { - pub(self) fn new( - inner: ClientInner, - conn_info: ConnInfo, - pool: Weak>>, - ) -> Self { - Self { - inner: Some(inner), - span: Span::current(), - conn_info, - pool, - } - } - pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { - let Self { - inner, - pool, - conn_info, - span: _, - } = self; - let inner = inner.as_mut().expect("client inner should not be removed"); - (&mut inner.inner, Discard { conn_info, pool }) + pub(crate) fn get_conn_id(&self) -> uuid::Uuid { + self.conn_id } -} -impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is not idle"); - } - } - pub(crate) fn discard(&mut self) { - let conn_info = &self.conn_info; - if std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); - } - } -} - -impl Deref for Client { - type Target = C; - - fn deref(&self) -> &Self::Target { - &self - .inner - .as_ref() - .expect("client inner should not be removed") - .inner - } -} - -impl Client { - fn do_drop(&mut self) -> Option { - let conn_info = self.conn_info.clone(); - let client = self - .inner - .take() - .expect("client inner should not be removed"); - if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { - let current_span = self.span.clone(); - // return connection to the pool - return Some(move || { - let _span = current_span.enter(); - EndpointConnPool::put(&conn_pool, &conn_info, client); - }); - } - None + pub(crate) fn is_closed(&self) -> bool { + self.inner.is_closed() } } -impl Drop for Client { +impl Drop for ClientInnerRemote { fn drop(&mut self) { - if let Some(drop) = self.do_drop() { - tokio::task::spawn_blocking(drop); - } + // on client drop, tell the conn to shut down + self.cancel.cancel(); } } @@ -745,12 +230,12 @@ mod tests { } } - fn create_inner() -> ClientInner { + fn create_inner() -> ClientInnerRemote { create_inner_with(MockClient::new(false)) } - fn create_inner_with(client: MockClient) -> ClientInner { - ClientInner { + fn create_inner_with(client: MockClient) -> ClientInnerRemote { + ClientInnerRemote { inner: client, session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()), cancel: CancellationToken::new(), @@ -797,7 +282,7 @@ mod tests { { let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); assert_eq!(0, pool.get_global_connections_count()); - client.inner().1.discard(); + client.inner_mut().1.discard(); // Discard should not add the connection from the pool. assert_eq!(0, pool.get_global_connections_count()); } diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs new file mode 100644 index 000000000000..6e964ce8789f --- /dev/null +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -0,0 +1,562 @@ +use dashmap::DashMap; +use parking_lot::RwLock; +use rand::Rng; +use std::{collections::HashMap, sync::Arc, sync::Weak, time::Duration}; +use std::{ + ops::Deref, + sync::atomic::{self, AtomicUsize}, +}; +use tokio_postgres::ReadyForQueryStatus; + +use crate::control_plane::messages::ColdStartInfo; +use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use crate::{ + auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName, +}; + +use super::conn_pool::ClientInnerRemote; +use tracing::info; +use tracing::{debug, Span}; + +use super::backend::HttpConnError; + +#[derive(Debug, Clone)] +pub(crate) struct ConnInfo { + pub(crate) user_info: ComputeUserInfo, + pub(crate) dbname: DbName, +} + +impl ConnInfo { + // hm, change to hasher to avoid cloning? + pub(crate) fn db_and_user(&self) -> (DbName, RoleName) { + (self.dbname.clone(), self.user_info.user.clone()) + } + + pub(crate) fn endpoint_cache_key(&self) -> Option { + // We don't want to cache http connections for ephemeral endpoints. + if self.user_info.options.is_ephemeral() { + None + } else { + Some(self.user_info.endpoint_cache_key()) + } + } +} + +pub(crate) struct ConnPoolEntry { + pub(crate) conn: ClientInnerRemote, + pub(crate) _last_access: std::time::Instant, +} + +// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool +// Number of open connections is limited by the `max_conns_per_endpoint`. +pub(crate) struct EndpointConnPool { + pools: HashMap<(DbName, RoleName), DbUserConnPool>, + total_conns: usize, + max_conns: usize, + _guard: HttpEndpointPoolsGuard<'static>, + global_connections_count: Arc, + global_pool_size_max_conns: usize, +} + +impl EndpointConnPool { + fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { + let Self { + pools, + total_conns, + global_connections_count, + .. + } = self; + pools.get_mut(&db_user).and_then(|pool_entries| { + let (entry, removed) = pool_entries.get_conn_entry(total_conns); + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + entry + }) + } + + pub(crate) fn remove_client( + &mut self, + db_user: (DbName, RoleName), + conn_id: uuid::Uuid, + ) -> bool { + let Self { + pools, + total_conns, + global_connections_count, + .. + } = self; + if let Some(pool) = pools.get_mut(&db_user) { + let old_len = pool.conns.len(); + pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id); + let new_len = pool.conns.len(); + let removed = old_len - new_len; + if removed > 0 { + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); + } + *total_conns -= removed; + removed > 0 + } else { + false + } + } + + pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerRemote) { + let conn_id = client.get_conn_id(); + + if client.is_closed() { + info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed"); + return; + } + + let global_max_conn = pool.read().global_pool_size_max_conns; + if pool + .read() + .global_connections_count + .load(atomic::Ordering::Relaxed) + >= global_max_conn + { + info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full"); + return; + } + + // return connection to the pool + let mut returned = false; + let mut per_db_size = 0; + let total_conns = { + let mut pool = pool.write(); + + if pool.total_conns < pool.max_conns { + let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); + pool_entries.conns.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); + + returned = true; + per_db_size = pool_entries.conns.len(); + + pool.total_conns += 1; + pool.global_connections_count + .fetch_add(1, atomic::Ordering::Relaxed); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .inc(); + } + + pool.total_conns + }; + + // do logging outside of the mutex + if returned { + info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); + } else { + info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); + } + } +} + +impl Drop for EndpointConnPool { + fn drop(&mut self) { + if self.total_conns > 0 { + self.global_connections_count + .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(self.total_conns as i64); + } + } +} + +pub(crate) struct DbUserConnPool { + pub(crate) conns: Vec>, +} + +impl Default for DbUserConnPool { + fn default() -> Self { + Self { conns: Vec::new() } + } +} + +impl DbUserConnPool { + fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { + let old_len = self.conns.len(); + + self.conns.retain(|conn| !conn.conn.is_closed()); + + let new_len = self.conns.len(); + let removed = old_len - new_len; + *conns -= removed; + removed + } + + pub(crate) fn get_conn_entry( + &mut self, + conns: &mut usize, + ) -> (Option>, usize) { + let mut removed = self.clear_closed_clients(conns); + let conn = self.conns.pop(); + if conn.is_some() { + *conns -= 1; + removed += 1; + } + + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); + + (conn, removed) + } +} + +pub(crate) struct GlobalConnPool { + // endpoint -> per-endpoint connection pool + // + // That should be a fairly conteded map, so return reference to the per-endpoint + // pool as early as possible and release the lock. + global_pool: DashMap>>>, + + /// Number of endpoint-connection pools + /// + /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. + /// That seems like far too much effort, so we're using a relaxed increment counter instead. + /// It's only used for diagnostics. + global_pool_size: AtomicUsize, + + /// Total number of connections in the pool + global_connections_count: Arc, + + config: &'static crate::config::HttpConfig, +} + +#[derive(Debug, Clone, Copy)] +pub struct GlobalConnPoolOptions { + // Maximum number of connections per one endpoint. + // Can mix different (dbname, username) connections. + // When running out of free slots for a particular endpoint, + // falls back to opening a new connection for each request. + pub max_conns_per_endpoint: usize, + + pub gc_epoch: Duration, + + pub pool_shards: usize, + + pub idle_timeout: Duration, + + pub opt_in: bool, + + // Total number of connections in the pool. + pub max_total_conns: usize, +} + +impl GlobalConnPool { + pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { + let shards = config.pool_options.pool_shards; + Arc::new(Self { + global_pool: DashMap::with_shard_amount(shards), + global_pool_size: AtomicUsize::new(0), + config, + global_connections_count: Arc::new(AtomicUsize::new(0)), + }) + } + + #[cfg(test)] + pub(crate) fn get_global_connections_count(&self) -> usize { + self.global_connections_count + .load(atomic::Ordering::Relaxed) + } + + pub(crate) fn get_idle_timeout(&self) -> Duration { + self.config.pool_options.idle_timeout + } + + pub(crate) fn shutdown(&self) { + // drops all strong references to endpoint-pools + self.global_pool.clear(); + } + + pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { + let epoch = self.config.pool_options.gc_epoch; + let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); + loop { + interval.tick().await; + + let shard = rng.gen_range(0..self.global_pool.shards().len()); + self.gc(shard); + } + } + + pub(crate) fn gc(&self, shard: usize) { + debug!(shard, "pool: performing epoch reclamation"); + + // acquire a random shard lock + let mut shard = self.global_pool.shards()[shard].write(); + + let timer = Metrics::get() + .proxy + .http_pool_reclaimation_lag_seconds + .start_timer(); + let current_len = shard.len(); + let mut clients_removed = 0; + shard.retain(|endpoint, x| { + // if the current endpoint pool is unique (no other strong or weak references) + // then it is currently not in use by any connections. + if let Some(pool) = Arc::get_mut(x.get_mut()) { + let EndpointConnPool { + pools, total_conns, .. + } = pool.get_mut(); + + // ensure that closed clients are removed + for db_pool in pools.values_mut() { + clients_removed += db_pool.clear_closed_clients(total_conns); + } + + // we only remove this pool if it has no active connections + if *total_conns == 0 { + info!("pool: discarding pool for endpoint {endpoint}"); + return false; + } + } + + true + }); + + let new_len = shard.len(); + drop(shard); + timer.observe(); + + // Do logging outside of the lock. + if clients_removed > 0 { + let size = self + .global_connections_count + .fetch_sub(clients_removed, atomic::Ordering::Relaxed) + - clients_removed; + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(clients_removed as i64); + info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); + } + let removed = current_len - new_len; + + if removed > 0 { + let global_pool_size = self + .global_pool_size + .fetch_sub(removed, atomic::Ordering::Relaxed) + - removed; + info!("pool: performed global pool gc. size now {global_pool_size}"); + } + } + + pub(crate) fn get_or_create_endpoint_pool( + self: &Arc, + endpoint: &EndpointCacheKey, + ) -> Arc>> { + // fast path + if let Some(pool) = self.global_pool.get(endpoint) { + return pool.clone(); + } + + // slow path + let new_pool = Arc::new(RwLock::new(EndpointConnPool { + pools: HashMap::new(), + total_conns: 0, + max_conns: self.config.pool_options.max_conns_per_endpoint, + _guard: Metrics::get().proxy.http_endpoint_pools.guard(), + global_connections_count: self.global_connections_count.clone(), + global_pool_size_max_conns: self.config.pool_options.max_total_conns, + })); + + // find or create a pool for this endpoint + let mut created = false; + let pool = self + .global_pool + .entry(endpoint.clone()) + .or_insert_with(|| { + created = true; + new_pool + }) + .clone(); + + // log new global pool size + if created { + let global_pool_size = self + .global_pool_size + .fetch_add(1, atomic::Ordering::Relaxed) + + 1; + info!( + "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" + ); + } + + pool + } + + pub(crate) fn get( + self: &Arc, + ctx: &RequestMonitoring, + conn_info: &ConnInfo, + ) -> Result>, HttpConnError> { + let mut client: Option> = None; + let Some(endpoint) = conn_info.endpoint_cache_key() else { + return Ok(None); + }; + + let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); + if let Some(entry) = endpoint_pool + .write() + .get_conn_entry(conn_info.db_and_user()) + { + client = Some(entry.conn); + } + let endpoint_pool = Arc::downgrade(&endpoint_pool); + + // ok return cached connection if found and establish a new one otherwise + if let Some(mut client) = client { + if client.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + return Ok(None); + } + tracing::Span::current() + .record("conn_id", tracing::field::display(client.get_conn_id())); + tracing::Span::current().record( + "pid", + tracing::field::display(client.inner().get_process_id()), + ); + info!( + cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), + "pool: reusing connection '{conn_info}'" + ); + + client.session().send(ctx.session_id())?; + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); + ctx.success(); + return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); + } + Ok(None) + } +} + +impl Client { + pub(crate) fn new( + inner: ClientInnerRemote, + conn_info: ConnInfo, + pool: Weak>>, + ) -> Self { + Self { + inner: Some(inner), + span: Span::current(), + conn_info, + pool, + } + } + + pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) { + let Self { + inner, + pool, + conn_info, + span: _, + } = self; + let inner = inner.as_mut().expect("client inner should not be removed"); + let inner_ref = inner.inner_mut(); + (inner_ref, Discard { conn_info, pool }) + } + + pub(crate) fn metrics(&self) -> Arc { + let aux = &self.inner.as_ref().unwrap().aux(); + USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id, + branch_id: aux.branch_id, + }) + } + + pub(crate) fn do_drop(&mut self) -> Option { + let conn_info = self.conn_info.clone(); + let client = self + .inner + .take() + .expect("client inner should not be removed"); + if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { + let current_span = self.span.clone(); + // return connection to the pool + return Some(move || { + let _span = current_span.enter(); + EndpointConnPool::put(&conn_pool, &conn_info, client); + }); + } + None + } +} + +pub(crate) struct Client { + span: Span, + inner: Option>, + conn_info: ConnInfo, + pool: Weak>>, +} + +impl Drop for Client { + fn drop(&mut self) { + if let Some(drop) = self.do_drop() { + tokio::task::spawn_blocking(drop); + } + } +} + +impl Deref for Client { + type Target = C; + + fn deref(&self) -> &Self::Target { + self.inner + .as_ref() + .expect("client inner should not be removed") + .inner() + } +} + +pub(crate) trait ClientInnerExt: Sync + Send + 'static { + fn is_closed(&self) -> bool; + fn get_process_id(&self) -> i32; +} + +impl ClientInnerExt for tokio_postgres::Client { + fn is_closed(&self) -> bool { + self.is_closed() + } + + fn get_process_id(&self) -> i32 { + self.get_process_id() + } +} + +pub(crate) struct Discard<'a, C: ClientInnerExt> { + conn_info: &'a ConnInfo, + pool: &'a mut Weak>>, +} + +impl Discard<'_, C> { + pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { + let conn_info = &self.conn_info; + if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { + info!("pool: throwing away connection '{conn_info}' because connection is not idle"); + } + } + pub(crate) fn discard(&mut self) { + let conn_info = &self.conn_info; + if std::mem::take(self.pool).strong_count() > 0 { + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); + } + } +} diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 9b6bc98557a5..79bb19328ffb 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -10,11 +10,12 @@ use rand::Rng; use tokio::net::TcpStream; use tracing::{debug, error, info, info_span, Instrument}; -use super::conn_pool::ConnInfo; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; + +use super::conn_pool_lib::{ClientInnerExt, ConnInfo}; use crate::EndpointCacheKey; pub(crate) type Send = http2::SendRequest; @@ -22,15 +23,15 @@ pub(crate) type Connect = http2::Connection, hyper::body::Incoming, TokioExecutor>; #[derive(Clone)] -struct ConnPoolEntry { - conn: Send, +pub(crate) struct ConnPoolEntry { + conn: C, conn_id: uuid::Uuid, aux: MetricsAuxInfo, } // Per-endpoint connection pool // Number of open connections is limited by the `max_conns_per_endpoint`. -pub(crate) struct EndpointConnPool { +pub(crate) struct EndpointConnPool { // TODO(conrad): // either we should open more connections depending on stream count // (not exposed by hyper, need our own counter) @@ -40,13 +41,13 @@ pub(crate) struct EndpointConnPool { // seems somewhat redundant though. // // Probably we should run a semaphore and just the single conn. TBD. - conns: VecDeque, + conns: VecDeque>, _guard: HttpEndpointPoolsGuard<'static>, global_connections_count: Arc, } -impl EndpointConnPool { - fn get_conn_entry(&mut self) -> Option { +impl EndpointConnPool { + fn get_conn_entry(&mut self) -> Option> { let Self { conns, .. } = self; loop { @@ -81,7 +82,7 @@ impl EndpointConnPool { } } -impl Drop for EndpointConnPool { +impl Drop for EndpointConnPool { fn drop(&mut self) { if !self.conns.is_empty() { self.global_connections_count @@ -95,12 +96,12 @@ impl Drop for EndpointConnPool { } } -pub(crate) struct GlobalConnPool { +pub(crate) struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>>, /// Number of endpoint-connection pools /// @@ -115,7 +116,7 @@ pub(crate) struct GlobalConnPool { config: &'static crate::config::HttpConfig, } -impl GlobalConnPool { +impl GlobalConnPool { pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { let shards = config.pool_options.pool_shards; Arc::new(Self { @@ -210,7 +211,7 @@ impl GlobalConnPool { self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, - ) -> Option { + ) -> Option> { let endpoint = conn_info.endpoint_cache_key()?; let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); let client = endpoint_pool.write().get_conn_entry()?; @@ -228,7 +229,7 @@ impl GlobalConnPool { fn get_or_create_endpoint_pool( self: &Arc, endpoint: &EndpointCacheKey, - ) -> Arc> { + ) -> Arc>> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); @@ -268,14 +269,14 @@ impl GlobalConnPool { } pub(crate) fn poll_http2_client( - global_pool: Arc, + global_pool: Arc>, ctx: &RequestMonitoring, conn_info: &ConnInfo, client: Send, connection: Connect, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> Client { +) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); let session_id = ctx.session_id(); @@ -322,13 +323,13 @@ pub(crate) fn poll_http2_client( Client::new(client, aux) } -pub(crate) struct Client { - pub(crate) inner: Send, +pub(crate) struct Client { + pub(crate) inner: C, aux: MetricsAuxInfo, } -impl Client { - pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self { +impl Client { + pub(self) fn new(inner: C, aux: MetricsAuxInfo) -> Self { Self { inner, aux } } @@ -339,3 +340,14 @@ impl Client { }) } } + +impl ClientInnerExt for Send { + fn is_closed(&self) -> bool { + self.is_closed() + } + + fn get_process_id(&self) -> i32 { + // ideally throw something meaningful + -1 + } +} diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 5df37a8762ff..c4fdd00f7859 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -20,11 +20,12 @@ use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument, Span}; use super::backend::HttpConnError; -use super::conn_pool::{ClientInnerExt, ConnInfo}; +use super::conn_pool_lib::{ClientInnerExt, ConnInfo}; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::Metrics; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; + use crate::{DbName, RoleName}; struct ConnPoolEntry { @@ -362,7 +363,7 @@ pub(crate) fn poll_client( LocalClient::new(inner, conn_info, pool_clone) } -struct ClientInner { +pub(crate) struct ClientInner { inner: C, session: tokio::sync::watch::Sender, cancel: CancellationToken, @@ -387,13 +388,24 @@ impl ClientInner { } } -impl LocalClient { - pub(crate) fn metrics(&self) -> Arc { - let aux = &self.inner.as_ref().unwrap().aux; - USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id, - branch_id: aux.branch_id, - }) +impl ClientInner { + pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { + self.jti += 1; + let token = resign_jwt(&self.key, payload, self.jti)?; + + // initiates the auth session + self.inner.simple_query("discard all").await?; + self.inner + .query( + "select auth.jwt_session_init($1)", + &[&token as &(dyn ToSql + Sync)], + ) + .await?; + + let pid = self.inner.get_process_id(); + info!(pid, jti = self.jti, "user session state init"); + + Ok(()) } } @@ -422,6 +434,18 @@ impl LocalClient { pool, } } + + pub(crate) fn client_inner(&mut self) -> (&mut ClientInner, Discard<'_, C>) { + let Self { + inner, + pool, + conn_info, + span: _, + } = self; + let inner_m = inner.as_mut().expect("client inner should not be removed"); + (inner_m, Discard { conn_info, pool }) + } + pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, @@ -434,33 +458,6 @@ impl LocalClient { } } -impl LocalClient { - pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> { - let inner = self - .inner - .as_mut() - .expect("client inner should not be removed"); - - inner.jti += 1; - let token = resign_jwt(&inner.key, payload, inner.jti)?; - - // initiates the auth session - inner.inner.simple_query("discard all").await?; - inner - .inner - .query( - "select auth.jwt_session_init($1)", - &[&token as &(dyn ToSql + Sync)], - ) - .await?; - - let pid = inner.inner.get_process_id(); - info!(pid, jti = inner.jti, "user session state init"); - - Ok(()) - } -} - /// implements relatively efficient in-place json object key upserting /// /// only supports top-level keys @@ -524,24 +521,15 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { jwt } -impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!( - "local_pool: throwing away connection '{conn_info}' because connection is not idle" - ); - } - } - pub(crate) fn discard(&mut self) { - let conn_info = &self.conn_info; - if std::mem::take(self.pool).strong_count() > 0 { - info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); - } +impl LocalClient { + pub(crate) fn metrics(&self) -> Arc { + let aux = &self.inner.as_ref().unwrap().aux; + USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id, + branch_id: aux.branch_id, + }) } -} -impl LocalClient { fn do_drop(&mut self) -> Option { let conn_info = self.conn_info.clone(); let client = self @@ -568,6 +556,23 @@ impl Drop for LocalClient { } } +impl Discard<'_, C> { + pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { + let conn_info = &self.conn_info; + if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { + info!( + "local_pool: throwing away connection '{conn_info}' because connection is not idle" + ); + } + } + pub(crate) fn discard(&mut self) { + let conn_info = &self.conn_info; + if std::mem::take(self.pool).strong_count() > 0 { + info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); + } + } +} + #[cfg(test)] mod tests { use p256::ecdsa::SigningKey; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 3ed3b6c845ce..29ff7b9d91c4 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -5,6 +5,7 @@ mod backend; pub mod cancel_set; mod conn_pool; +mod conn_pool_lib; mod http_conn_pool; mod http_util; mod json; @@ -20,7 +21,7 @@ use anyhow::Context; use async_trait::async_trait; use atomic_take::AtomicTake; use bytes::Bytes; -pub use conn_pool::GlobalConnPoolOptions; +pub use conn_pool_lib::GlobalConnPoolOptions; use futures::future::{select, Either}; use futures::TryFutureExt; use http::{Method, Response, StatusCode}; @@ -65,7 +66,7 @@ pub async fn task_main( } let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config); - let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config); + let conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config); { let conn_pool = Arc::clone(&conn_pool); tokio::spawn(async move { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 3d8a2adef198..bb5eb390a6bc 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -25,10 +25,11 @@ use urlencoding; use utils::http::error::ApiError; use super::backend::{LocalProxyConnError, PoolingBackend}; -use super::conn_pool::{AuthData, ConnInfo, ConnInfoWithAuth}; +use super::conn_pool::{AuthData, ConnInfoWithAuth}; +use super::conn_pool_lib::{self, ConnInfo}; use super::http_util::json_response; use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError}; -use super::{conn_pool, local_conn_pool}; +use super::local_conn_pool; use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::{endpoint_sni, ComputeUserInfoParseError}; use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig}; @@ -37,6 +38,7 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{HttpDirection, Metrics}; use crate::proxy::{run_until_cancelled, NeonOptions}; use crate::serverless::backend::HttpConnError; + use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; use crate::{DbName, RoleName}; @@ -607,7 +609,8 @@ async fn handle_db_inner( let client = match keys.keys { ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => { let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?; - client.set_jwt_session(&payload).await?; + let (cli_inner, _dsc) = client.client_inner(); + cli_inner.set_jwt_session(&payload).await?; Client::Local(client) } _ => { @@ -1021,12 +1024,12 @@ async fn query_to_json( } enum Client { - Remote(conn_pool::Client), + Remote(conn_pool_lib::Client), Local(local_conn_pool::LocalClient), } enum Discard<'a> { - Remote(conn_pool::Discard<'a, tokio_postgres::Client>), + Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>), Local(local_conn_pool::Discard<'a, tokio_postgres::Client>), } @@ -1041,7 +1044,7 @@ impl Client { fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { match self { Client::Remote(client) => { - let (c, d) = client.inner(); + let (c, d) = client.inner_mut(); (c, Discard::Remote(d)) } Client::Local(local_client) => {