From cff6927e4f58b1af6ecc2ee7279df1f2ff537295 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 20 Jun 2024 11:10:39 +0100 Subject: [PATCH] allow arbitrary config params with more efficient repr (#30) * switch to compressed config * remove some allocs and combine pw and authkeys into a single enum * allow arbitrary options * remove some redundant configs * clean up * more cleanup * replication * fix * fix lints * simplify * dont treat user separately * do not duplicate the encoding of params --- postgres-protocol/src/authentication/sasl.rs | 4 +- postgres-protocol/src/lib.rs | 2 +- postgres-protocol/src/message/frontend.rs | 60 ++++++++ postgres-types/src/lib.rs | 2 +- postgres-types/src/special.rs | 1 - postgres/src/config.rs | 24 ---- tokio-postgres/src/config.rs | 141 ++++++++----------- tokio-postgres/src/connect_raw.rs | 66 ++++----- tokio-postgres/tests/test/main.rs | 14 ++ tokio-postgres/tests/test/replication.rs | 1 + 10 files changed, 159 insertions(+), 156 deletions(-) diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index 19aa3c1e9..f2200a40c 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -117,7 +117,7 @@ enum Credentials { /// A regular password as a vector of bytes. Password(Vec), /// A precomputed pair of keys. - Keys(Box>), + Keys(ScramKeys), } enum State { @@ -176,7 +176,7 @@ impl ScramSha256 { /// Constructs a new instance which will use the provided key pair for authentication. pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { - let password = Credentials::Keys(keys.into()); + let password = Credentials::Keys(keys); ScramSha256::new_inner(password, channel_binding, nonce()) } diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index 1f7aa7923..5f6ecf15f 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -68,7 +68,7 @@ macro_rules! from_usize { impl FromUsize for $t { #[inline] fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::max_value() as usize { + if x > <$t>::MAX as usize { Err(io::Error::new( io::ErrorKind::InvalidInput, "value too large to transmit", diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs index 5d0a8ff8c..dabed0bab 100644 --- a/postgres-protocol/src/message/frontend.rs +++ b/postgres-protocol/src/message/frontend.rs @@ -271,6 +271,66 @@ where }) } +#[inline] +pub fn startup_message_cstr( + parameters: &StartupMessageParams, + buf: &mut BytesMut, +) -> io::Result<()> { + write_body(buf, |buf| { + // postgres protocol version 3.0(196608) in bigger-endian + buf.put_i32(0x00_03_00_00); + buf.put_slice(¶meters.params); + buf.put_u8(0); + Ok(()) + }) +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct StartupMessageParams { + params: BytesMut, +} + +impl StartupMessageParams { + /// Set parameter's value by its name. + pub fn insert(&mut self, name: &str, value: &str) -> Result<(), io::Error> { + if name.contains('\0') | value.contains('\0') { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); + } + self.params.put(name.as_bytes()); + self.params.put(&b"\0"[..]); + self.params.put(value.as_bytes()); + self.params.put(&b"\0"[..]); + Ok(()) + } + + pub fn str_iter(&self) -> impl Iterator { + let params = + std::str::from_utf8(&self.params).expect("should be validated as utf8 already"); + StrParamsIter(params) + } + + /// Get parameter's value by its name. + pub fn get(&self, name: &str) -> Option<&str> { + self.str_iter().find_map(|(k, v)| (k == name).then_some(v)) + } +} + +struct StrParamsIter<'a>(&'a str); + +impl<'a> Iterator for StrParamsIter<'a> { + type Item = (&'a str, &'a str); + + fn next(&mut self) -> Option { + let (key, r) = self.0.split_once('\0')?; + let (value, r) = r.split_once('\0')?; + self.0 = r; + Some((key, value)) + } +} + #[inline] pub fn sync(buf: &mut BytesMut) { buf.put_u8(b'S'); diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index b10d298de..79dd92996 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -1172,7 +1172,7 @@ impl ToSql for IpAddr { } fn downcast(len: usize) -> Result> { - if len > i32::max_value() as usize { + if len > i32::MAX as usize { Err("value too large to transmit".into()) } else { Ok(len as i32) diff --git a/postgres-types/src/special.rs b/postgres-types/src/special.rs index 1a865287e..d8541bf0e 100644 --- a/postgres-types/src/special.rs +++ b/postgres-types/src/special.rs @@ -1,7 +1,6 @@ use bytes::BytesMut; use postgres_protocol::types; use std::error::Error; -use std::{i32, i64}; use crate::{FromSql, IsNull, ToSql, Type}; diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 44e4bec3a..ccbbe7c51 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -145,12 +145,6 @@ impl Config { self } - /// Gets the password to authenticate with, if one has been configured with - /// the `password` method. - pub fn get_password(&self) -> Option<&[u8]> { - self.config.get_password() - } - /// Sets precomputed protocol-specific keys to authenticate with. /// When set, this option will override `password`. /// See [`AuthKeys`] for more information. @@ -159,12 +153,6 @@ impl Config { self } - /// Gets precomputed protocol-specific keys to authenticate with. - /// if one has been configured with the `auth_keys` method. - pub fn get_auth_keys(&self) -> Option { - self.config.get_auth_keys() - } - /// Sets the name of the database to connect to. /// /// Defaults to the user. @@ -185,24 +173,12 @@ impl Config { self } - /// Gets the command line options used to configure the server, if the - /// options have been set with the `options` method. - pub fn get_options(&self) -> Option<&str> { - self.config.get_options() - } - /// Sets the value of the `application_name` runtime parameter. pub fn application_name(&mut self, application_name: &str) -> &mut Config { self.config.application_name(application_name); self } - /// Gets the value of the `application_name` runtime parameter, if it has - /// been set with the `application_name` method. - pub fn get_application_name(&self) -> Option<&str> { - self.config.get_application_name() - } - /// Sets the SSL configuration. /// /// Defaults to `prefer`. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index fdb5e6359..f6cff7bb0 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -10,6 +10,7 @@ use crate::tls::TlsConnect; #[cfg(feature = "runtime")] use crate::Socket; use crate::{Client, Connection, Error}; +use postgres_protocol::message::frontend::StartupMessageParams; use std::borrow::Cow; #[cfg(unix)] use std::ffi::OsStr; @@ -170,12 +171,7 @@ pub enum AuthKeys { /// ``` #[derive(Clone, PartialEq, Eq)] pub struct Config { - pub(crate) user: Option, - pub(crate) password: Option>, - pub(crate) auth_keys: Option>, - pub(crate) dbname: Option, - pub(crate) options: Option, - pub(crate) application_name: Option, + pub(crate) auth: Option, pub(crate) ssl_mode: SslMode, pub(crate) host: Vec, pub(crate) port: Vec, @@ -184,8 +180,18 @@ pub struct Config { pub(crate) keepalive_config: KeepaliveConfig, pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, - pub(crate) replication_mode: Option, pub(crate) max_backend_message_size: Option, + pub(crate) server_settings: StartupMessageParams, +} + +#[derive(Clone, PartialEq, Eq)] +#[non_exhaustive] +/// What auth info to use when authenticating +pub enum Auth { + /// password based auth + Password(Vec), + /// precomputed scram based auth + AuthKeys(AuthKeys), } impl Default for Config { @@ -203,12 +209,7 @@ impl Config { retries: None, }; Config { - user: None, - password: None, - auth_keys: None, - dbname: None, - options: None, - application_name: None, + auth: None, ssl_mode: SslMode::Prefer, host: vec![], port: vec![], @@ -217,8 +218,8 @@ impl Config { keepalive_config, target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, - replication_mode: None, max_backend_message_size: None, + server_settings: StartupMessageParams::default(), } } @@ -226,14 +227,14 @@ impl Config { /// /// Required. pub fn user(&mut self, user: &str) -> &mut Config { - self.user = Some(user.to_string()); + self.server_settings.insert("user", user).unwrap(); self } /// Gets the user to authenticate with, if one has been configured with /// the `user` method. pub fn get_user(&self) -> Option<&str> { - self.user.as_deref() + self.server_settings.get("user") } /// Sets the password to authenticate with. @@ -241,68 +242,60 @@ impl Config { where T: AsRef<[u8]>, { - self.password = Some(password.as_ref().to_vec()); + self.auth = Some(Auth::Password(password.as_ref().to_vec())); self } /// Gets the password to authenticate with, if one has been configured with /// the `password` method. - pub fn get_password(&self) -> Option<&[u8]> { - self.password.as_deref() + pub fn get_auth(&self) -> Option { + self.auth.clone() } /// Sets precomputed protocol-specific keys to authenticate with. /// When set, this option will override `password`. /// See [`AuthKeys`] for more information. - pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { - self.auth_keys = Some(Box::new(keys)); + pub fn auth(&mut self, keys: Auth) -> &mut Config { + self.auth = Some(keys); self } - /// Gets precomputed protocol-specific keys to authenticate with. - /// if one has been configured with the `auth_keys` method. - pub fn get_auth_keys(&self) -> Option { - self.auth_keys.as_deref().copied() + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth = Some(Auth::AuthKeys(keys)); + self } /// Sets the name of the database to connect to. /// /// Defaults to the user. pub fn dbname(&mut self, dbname: &str) -> &mut Config { - self.dbname = Some(dbname.to_string()); + self.server_settings.insert("database", dbname).unwrap(); self } /// Gets the name of the database to connect to, if one has been configured /// with the `dbname` method. pub fn get_dbname(&self) -> Option<&str> { - self.dbname.as_deref() + self.server_settings.get("database") } /// Sets command line options used to configure the server. pub fn options(&mut self, options: &str) -> &mut Config { - self.options = Some(options.to_string()); + self.server_settings.insert("options", options).unwrap(); self } - /// Gets the command line options used to configure the server, if the - /// options have been set with the `options` method. - pub fn get_options(&self) -> Option<&str> { - self.options.as_deref() - } - /// Sets the value of the `application_name` runtime parameter. pub fn application_name(&mut self, application_name: &str) -> &mut Config { - self.application_name = Some(application_name.to_string()); + self.server_settings + .insert("application_name", application_name) + .unwrap(); self } - /// Gets the value of the `application_name` runtime parameter, if it has - /// been set with the `application_name` method. - pub fn get_application_name(&self) -> Option<&str> { - self.application_name.as_deref() - } - /// Sets the SSL configuration. /// /// Defaults to `prefer`. @@ -465,15 +458,18 @@ impl Config { /// Set replication mode. pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { - self.replication_mode = Some(replication_mode); + match replication_mode { + ReplicationMode::Physical => { + self.server_settings.insert("replication", "true").unwrap() + } + ReplicationMode::Logical => self + .server_settings + .insert("replication", "database") + .unwrap(), + } self } - /// Get replication mode. - pub fn get_replication_mode(&self) -> Option { - self.replication_mode - } - /// Set limit for backend messages size. pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { self.max_backend_message_size = Some(max_backend_message_size); @@ -485,7 +481,8 @@ impl Config { self.max_backend_message_size } - fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + /// Set an arbitrary param + pub fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { self.user(value); @@ -496,12 +493,6 @@ impl Config { "dbname" => { self.dbname(value); } - "options" => { - self.options(value); - } - "application_name" => { - self.application_name(value); - } "sslmode" => { let mode = match value { "disable" => SslMode::Disable, @@ -588,17 +579,6 @@ impl Config { }; self.channel_binding(channel_binding); } - "replication" => { - let mode = match value { - "off" => None, - "true" => Some(ReplicationMode::Physical), - "database" => Some(ReplicationMode::Logical), - _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))), - }; - if let Some(mode) = mode { - self.replication_mode(mode); - } - } "max_backend_message_size" => { let limit = value.parse::().map_err(|_| { Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) @@ -608,9 +588,9 @@ impl Config { } } key => { - return Err(Error::config_parse(Box::new(UnknownOption( - key.to_string(), - )))); + self.server_settings + .insert(key, value) + .map_err(|e| Error::config_parse(e.into()))?; } } @@ -665,12 +645,8 @@ impl fmt::Debug for Config { } } - f.debug_struct("Config") - .field("user", &self.user) - .field("password", &self.password.as_ref().map(|_| Redaction {})) - .field("dbname", &self.dbname) - .field("options", &self.options) - .field("application_name", &self.application_name) + let mut f = f.debug_struct("Config"); + f.field("auth", &self.auth.as_ref().map(|_| Redaction {})) .field("ssl_mode", &self.ssl_mode) .field("host", &self.host) .field("port", &self.port) @@ -680,23 +656,16 @@ impl fmt::Debug for Config { .field("keepalives_interval", &self.keepalive_config.interval) .field("keepalives_retries", &self.keepalive_config.retries) .field("target_session_attrs", &self.target_session_attrs) - .field("channel_binding", &self.channel_binding) - .field("replication", &self.replication_mode) - .finish() - } -} + .field("channel_binding", &self.channel_binding); -#[derive(Debug)] -struct UnknownOption(String); + for (k, v) in self.server_settings.str_iter() { + f.field(k, &v); + } -impl fmt::Display for UnknownOption { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(fmt, "unknown option `{}`", self.0) + f.finish() } } -impl error::Error for UnknownOption {} - #[derive(Debug)] struct InvalidValue(&'static str); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 8e788984a..4d3f58b78 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::{self, AuthKeys, Config, ReplicationMode}; +use crate::config::{self, Auth, AuthKeys, Config}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; @@ -116,28 +116,14 @@ where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { - let mut params = vec![("client_encoding", "UTF8")]; - if let Some(user) = &config.user { - params.push(("user", &**user)); - } - if let Some(dbname) = &config.dbname { - params.push(("database", &**dbname)); - } - if let Some(options) = &config.options { - params.push(("options", &**options)); - } - if let Some(application_name) = &config.application_name { - params.push(("application_name", &**application_name)); - } - if let Some(replication_mode) = &config.replication_mode { - match replication_mode { - ReplicationMode::Physical => params.push(("replication", "true")), - ReplicationMode::Logical => params.push(("replication", "database")), - } - } + // leave for user to provide: + // let mut params = config.server_settings.clone(); + // params + // .insert("client_encoding", "UTF8") + // .map_err(Error::encode)?; let mut buf = BytesMut::new(); - frontend::startup_message(params, &mut buf).map_err(Error::encode)?; + frontend::startup_message_cstr(&config.server_settings, &mut buf).map_err(Error::encode)?; stream .send(FrontendMessage::Raw(buf.freeze())) @@ -158,27 +144,25 @@ where Some(Message::AuthenticationCleartextPassword) => { can_skip_channel_binding(config)?; - let pass = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - - authenticate_password(stream, pass).await?; + match &config.auth { + Some(Auth::Password(pass)) => authenticate_password(stream, pass).await?, + _ => return Err(Error::config("password missing".into())), + } } Some(Message::AuthenticationMd5Password(body)) => { can_skip_channel_binding(config)?; let user = config - .user - .as_ref() + .get_user() .ok_or_else(|| Error::config("user missing".into()))?; - let pass = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); - authenticate_password(stream, output.as_bytes()).await?; + match &config.auth { + Some(Auth::Password(pass)) => { + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); + authenticate_password(stream, output.as_bytes()).await?; + } + _ => return Err(Error::config("password missing".into())), + } } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; @@ -276,12 +260,12 @@ where can_skip_channel_binding(config)?; } - let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { - ScramSha256::new_with_keys(keys, channel_binding) - } else if let Some(password) = config.get_password() { - ScramSha256::new(password, channel_binding) - } else { - return Err(Error::config("password or auth keys missing".into())); + let mut scram = match &config.auth { + Some(Auth::AuthKeys(AuthKeys::ScramSha256(keys))) => { + ScramSha256::new_with_keys(*keys, channel_binding) + } + Some(Auth::Password(password)) => ScramSha256::new(password, channel_binding), + None => return Err(Error::config("password or auth keys missing".into())), }; let mut buf = BytesMut::new(); diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 772612de6..c074bb0d1 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -166,6 +166,20 @@ async fn pipelined_prepare() { assert_eq!(statement2.columns()[0].type_(), &Type::INT8); } +// regression: https://github.com/neondatabase/neon/issues/1287#issuecomment-1251922486 +#[tokio::test] +#[cfg(feature = "with-serde_json-1")] +async fn custom_params() { + let client = connect("user=postgres IntervalStyle=iso_8601").await; + + let row = client + .query_one("select to_json('0 seconds'::interval)", &[]) + .await + .unwrap(); + + assert_eq!(row.get::<_, serde_json_1::Value>(0), "PT0S"); +} + #[tokio::test] async fn insert_select() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/replication.rs b/tokio-postgres/tests/test/replication.rs index c176a4104..b510d8879 100644 --- a/tokio-postgres/tests/test/replication.rs +++ b/tokio-postgres/tests/test/replication.rs @@ -10,6 +10,7 @@ use tokio_postgres::NoTls; use tokio_postgres::SimpleQueryMessage::Row; #[tokio::test] +#[ignore = "replication"] async fn test_replication() { // form SQL connection let conninfo = "host=127.0.0.1 port=5433 user=postgres replication=database";