From c5b1d750575c3fde9e3e36faed400fffb6c20bfa Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 23 May 2023 11:32:41 +0300 Subject: [PATCH 01/10] Add text protocol based query method (#14) Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Use text protocol for responses -- that allows to grab postgres-provided serializations for types. Catch command tag. Expose row buffer size and add `max_backend_message_size` option to prevent handling and storing in memory large messages from the backend. Co-authored-by: Arthur Petukhovsky --- .github/workflows/ci.yml | 4 +- postgres-types/src/lib.rs | 18 ++++++- tokio-postgres/src/client.rs | 85 ++++++++++++++++++++++++++++++- tokio-postgres/src/codec.rs | 13 ++++- tokio-postgres/src/config.rs | 23 +++++++++ tokio-postgres/src/connect_raw.rs | 7 ++- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/query.rs | 26 ++++++++++ tokio-postgres/src/row.rs | 22 ++++++++ tokio-postgres/src/statement.rs | 23 +++++++++ tokio-postgres/tests/test/main.rs | 72 ++++++++++++++++++++++++++ 11 files changed, 289 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 008158fb0..549340329 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,7 +53,9 @@ jobs: steps: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + with: + version: 1.65.0 + - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown - uses: actions/cache@v3 diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 52b5c773a..531a9f719 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -442,6 +442,22 @@ impl WrongType { } } +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + /// A trait for types that can be created from a Postgres value. /// /// # Types @@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 427a05049..a64ad5d9d 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -4,8 +4,10 @@ use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; +use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; +use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -16,7 +18,7 @@ use crate::{ copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -368,6 +370,87 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let params = params.into_iter(); + let params_len = params.len(); + + let buf = self.inner.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = self + .inner + .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(&self.inner, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b178eac80..9614f19f9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -207,6 +207,8 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -240,6 +242,8 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + replication_mode: None, + max_backend_message_size: None, } } @@ -520,6 +524,17 @@ impl Config { self.load_balance_hosts } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -655,6 +670,14 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..7124557de 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -92,7 +92,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..ba8d5a43e 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..ddc5dd27c 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -53,6 +53,7 @@ where statement, responses, rows_affected: None, + command_tag: None, _p: PhantomPinned, }) } @@ -74,6 +75,7 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, rows_affected: None, + command_tag: None, _p: PhantomPinned, }) } @@ -208,11 +210,24 @@ pin_project! { statement: Statement, responses: Responses, rows_affected: Option, + command_tag: Option, #[pin] _p: PhantomPinned, } } +impl RowStream { + /// Creates a new `RowStream`. + pub fn new(statement: Statement, responses: Responses) -> Self { + RowStream { + statement, + responses, + command_tag: None, + _p: PhantomPinned, + } + } +} + impl Stream for RowStream { type Item = Result; @@ -225,6 +240,10 @@ impl Stream for RowStream { } Message::CommandComplete(body) => { *this.rows_affected = Some(extract_row_affected(&body)?); + + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::ReadyForQuery(_) => return Poll::Ready(None), @@ -241,4 +260,11 @@ impl RowStream { pub fn rows_affected(&self) -> Option { self.rows_affected } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..ce4efed7e 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -187,6 +188,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.statement.output_format() == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..b7ab11866 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,6 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -13,6 +14,7 @@ struct StatementInner { name: String, params: Vec, columns: Vec, + output_format: Format, } impl Drop for StatementInner { @@ -46,6 +48,22 @@ impl Statement { name, params, columns, + output_format: Format::Binary, + })) + } + + pub(crate) fn new_text( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + output_format: Format::Text, })) } @@ -62,6 +80,11 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } + + /// Returns output format for the statement. + pub fn output_format(&self) -> Format { + self.0.output_format + } } /// Information about a column of a query. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..3ef38f01a 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -249,6 +249,78 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); +} + +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .await + .unwrap(); + + pin_mut!(row_stream); + + let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From feedb8adc53ddcfe326db0942470eeb3242ec209 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Thu, 8 Jun 2023 10:58:15 +0300 Subject: [PATCH 02/10] Allow passing null params in query_raw_txt() Previous coding only allowed passing vector of text values as params, but that does not allow to distinguish between nulls and 4-byte strings with "null" written in them. Change query_raw_txt params argument to accept Vec> instead. --- tokio-postgres/src/client.rs | 11 ++++++---- tokio-postgres/tests/test/main.rs | 34 +++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index a64ad5d9d..8912de774 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -375,7 +375,7 @@ impl Client { pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result where S: AsRef, - I: IntoIterator, + I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); @@ -390,9 +390,12 @@ impl Client { "", // empty string selects the unnamed prepared statement std::iter::empty(), // all parameters use the default format (text) params, - |param, buf| { - buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol::IsNull::No) + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), }, Some(0), // all text buf, diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 3ef38f01a..4adec65bb 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -254,7 +254,7 @@ async fn query_raw_txt() { let client = connect("user=postgres").await; let rows: Vec = client - .query_raw_txt("SELECT 55 * $1", ["42"]) + .query_raw_txt("SELECT 55 * $1", [Some("42")]) .await .unwrap() .try_collect() @@ -266,7 +266,7 @@ async fn query_raw_txt() { assert_eq!(res, 55 * 42); let rows: Vec = client - .query_raw_txt("SELECT $1", ["42"]) + .query_raw_txt("SELECT $1", [Some("42")]) .await .unwrap() .try_collect() @@ -278,6 +278,36 @@ async fn query_raw_txt() { assert!(rows[0].body_len() > 0); } +#[tokio::test] +async fn query_raw_txt_nulls() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "SELECT $1 as str, $2 as n, 'null' as str2, null as n2", + [Some("null"), None], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + + let res = rows[0].as_text(0).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(1).unwrap(); + assert_eq!(res, None); + + let res = rows[0].as_text(2).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(3).unwrap(); + assert_eq!(res, None); +} + #[tokio::test] async fn limit_max_backend_message_size() { let client = connect("user=postgres max_backend_message_size=10000").await; From b6921685ca6cbdfcdad3ea84b8ae3c53ce032d47 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 13 Jun 2023 01:11:08 +0300 Subject: [PATCH 03/10] Return more RowDescription fields As we are trying to match client-side behaviour with node-postgres we need to return this fields as well because node-postgres returns them. --- tokio-postgres/src/client.rs | 4 ++- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/statement.rs | 58 +++++++++++++++++++++++++++++-- tokio-postgres/tests/test/main.rs | 25 +++++++++++++ 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 8912de774..d48fc7fcb 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -443,8 +443,10 @@ impl Client { if let Some(row_description) = row_description { let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { + // NB: for some types that function may send a query to the server. At least in + // raw text mode we don't need that info and can skip this. let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index ba8d5a43e..0abb8e453 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -95,7 +95,7 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index b7ab11866..8743f00f0 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -2,7 +2,10 @@ use crate::client::InnerClient; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; -use postgres_protocol::message::frontend; +use postgres_protocol::{ + message::{backend::Field, frontend}, + Oid, +}; use postgres_types::Format; use std::{ fmt, @@ -91,11 +94,30 @@ impl Statement { pub struct Column { name: String, type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } } /// Returns the name of the column. @@ -107,6 +129,36 @@ impl Column { pub fn type_(&self) -> &Type { &self.type_ } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } } impl fmt::Debug for Column { diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 4adec65bb..40ff0d7e5 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -351,6 +351,31 @@ async fn command_tag() { assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); } +#[tokio::test] +async fn column_extras() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("select relacl, relname from pg_class limit 1", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let column = rows[0].columns().get(1).unwrap(); + assert_eq!(column.name(), "relname"); + assert_eq!(column.type_(), &Type::NAME); + + assert!(column.table_oid() > 0); + assert_eq!(column.column_id(), 2); + assert_eq!(column.format(), 0); + + assert_eq!(column.type_oid(), 19); + assert_eq!(column.type_size(), 64); + assert_eq!(column.type_modifier(), -1); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From 3b8d078f73fd550df7c1e59f29258d4b952891fa Mon Sep 17 00:00:00 2001 From: Alex Chi Z Date: Mon, 24 Jul 2023 15:15:14 -0400 Subject: [PATCH 04/10] add query_raw_txt for transaction (#20) Signed-off-by: Alex Chi --- tokio-postgres/src/client.rs | 85 ++----------------------- tokio-postgres/src/generic_client.rs | 25 ++++++++ tokio-postgres/src/query.rs | 94 +++++++++++++++++++++++++++- tokio-postgres/src/transaction.rs | 10 +++ 4 files changed, 131 insertions(+), 83 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index d48fc7fcb..ab0de2cef 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -4,10 +4,8 @@ use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; -use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; -use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -18,7 +16,7 @@ use crate::{ copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BufMut, BytesMut}; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -374,86 +372,11 @@ impl Client { /// to save a roundtrip pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result where - S: AsRef, + S: AsRef + Sync + Send, I: IntoIterator>, - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator + Sync + Send, { - let params = params.into_iter(); - let params_len = params.len(); - - let buf = self.inner.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary - match frontend::bind( - "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement - std::iter::empty(), // all parameters use the default format (text) - params, - |param, buf| match param { - Some(param) => { - buf.put_slice(param.as_ref().as_bytes()); - Ok(postgres_protocol::IsNull::No) - } - None => Ok(postgres_protocol::IsNull::Yes), - }, - Some(0), // all text - buf, - ) { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - }?; - - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; - // Execute - frontend::execute("", 0, buf).map_err(Error::encode)?; - // Sync - frontend::sync(buf); - - Ok(buf.split().freeze()) - })?; - - let mut responses = self - .inner - .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(&self.inner, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + query::query_txt(&self.inner, query, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..a259532e5 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_raw_txt`. + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -136,6 +143,15 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -222,6 +238,15 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(query, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index ddc5dd27c..b642a682c 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,21 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; -use bytes::{Bytes, BytesMut}; +use crate::{Column, Error, Portal, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; +use postgres_types::Type; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -58,6 +62,92 @@ where }) } +pub async fn query_txt( + client: &Arc, + query: S, + params: I, +) -> Result +where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + let params_len = params.len(); + + let buf = client.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + // NB: for some types that function may send a query to the server. At least in + // raw text mode we don't need that info and can skip this. + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + let statement = Statement::new_text(client, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..806196aa3 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,16 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, query: S, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.client.query_raw_txt(query, params).await + } + /// Like `Client::execute`. pub async fn execute( &self, From c5c8c9f515be3d1de8306744387666c2e4f8032b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 11 Aug 2023 15:14:05 +0100 Subject: [PATCH 05/10] Connection changes (#21) * refactor query_raw_txt to use a pre-prepared statement * expose ready_status on RowStream --- .github/workflows/ci.yml | 2 +- tokio-postgres/src/client.rs | 14 ++-- tokio-postgres/src/config.rs | 2 - tokio-postgres/src/generic_client.rs | 17 +++-- tokio-postgres/src/query.rs | 99 +++++++++++----------------- tokio-postgres/src/row.rs | 10 ++- tokio-postgres/src/statement.rs | 23 ------- tokio-postgres/src/transaction.rs | 9 +-- tokio-postgres/tests/test/main.rs | 36 ++++++++-- 9 files changed, 105 insertions(+), 107 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 549340329..431e17748 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master with: - version: 1.65.0 + version: 1.67.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ab0de2cef..e67553101 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -370,13 +370,19 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + pub async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, query, params).await + let statement = statement.__convert().into_statement(self).await?; + query::query_txt(&self.inner, statement, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 9614f19f9..2547469ec 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -207,7 +207,6 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, - pub(crate) replication_mode: Option, pub(crate) max_backend_message_size: Option, } @@ -242,7 +241,6 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, - replication_mode: None, max_backend_message_size: None, } } diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index a259532e5..a4ee4808b 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -57,8 +57,13 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator; /// Like `Client::query_raw_txt`. - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; @@ -143,13 +148,14 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { @@ -238,13 +244,14 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result where + T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, { - self.query_raw_txt(query, params).await + self.query_raw_txt(statement, params).await } async fn prepare(&self, query: &str) -> Result { diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index b642a682c..7cf9580e5 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,17 +1,15 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::prepare::get_type; use crate::types::{BorrowToSql, IsNull}; -use crate::{Column, Error, Portal, Row, Statement}; +use crate::{Error, Portal, Row, Statement}; use bytes::{BufMut, Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; -use postgres_types::Type; +use postgres_types::Format; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; @@ -58,30 +56,29 @@ where responses, rows_affected: None, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } pub async fn query_txt( client: &Arc, - query: S, + statement: Statement, params: I, ) -> Result where - S: AsRef + Sync + Send, + S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); - let params_len = params.len(); let buf = client.with_buf(|buf| { - // Parse, anonymous portal - frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; // Bind, pass params as text, retrieve as binary match frontend::bind( "", // empty string selects the unnamed portal - "", // empty string selects the unnamed prepared statement + statement.name(), // named prepared statement std::iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { @@ -99,8 +96,6 @@ where Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), }?; - // Describe portal to typecast results - frontend::describe(b'P', "", buf).map_err(Error::encode)?; // Execute frontend::execute("", 0, buf).map_err(Error::encode)?; // Sync @@ -109,43 +104,17 @@ where Ok(buf.split().freeze()) })?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - // now read the responses - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - - // construct statement object - - let parameters = vec![Type::UNKNOWN; params_len]; - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - // NB: for some types that function may send a query to the server. At least in - // raw text mode we don't need that info and can skip this. - let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - - let statement = Statement::new_text(client, "".to_owned(), parameters, columns); - - Ok(RowStream::new(statement, responses)) + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: None, + output_format: Format::Text, + _p: PhantomPinned, + rows_affected: None, + }) } pub async fn query_portal( @@ -166,6 +135,8 @@ pub async fn query_portal( responses, rows_affected: None, command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } @@ -301,23 +272,13 @@ pin_project! { responses: Responses, rows_affected: Option, command_tag: Option, + output_format: Format, + status: Option, #[pin] _p: PhantomPinned, } } -impl RowStream { - /// Creates a new `RowStream`. - pub fn new(statement: Statement, responses: Responses) -> Self { - RowStream { - statement, - responses, - command_tag: None, - _p: PhantomPinned, - } - } -} - impl Stream for RowStream { type Item = Result; @@ -326,7 +287,11 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) } Message::CommandComplete(body) => { *this.rows_affected = Some(extract_row_affected(&body)?); @@ -336,7 +301,10 @@ impl Stream for RowStream { } } Message::EmptyQueryResponse | Message::PortalSuspended => {} - Message::ReadyForQuery(_) => return Poll::Ready(None), + Message::ReadyForQuery(status) => { + *this.status = Some(status.status()); + return Poll::Ready(None); + } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } @@ -357,4 +325,11 @@ impl RowStream { pub fn command_tag(&self) -> Option { self.command_tag.clone() } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> Option { + self.status + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index ce4efed7e..754b5f28c 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -98,6 +98,7 @@ where /// A row of data returned from the database by a query. pub struct Row { statement: Statement, + output_format: Format, body: DataRowBody, ranges: Vec>>, } @@ -111,12 +112,17 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { statement, body, ranges, + output_format, }) } @@ -193,7 +199,7 @@ impl Row { /// /// Useful when using query_raw_txt() which sets text transfer mode pub fn as_text(&self, idx: usize) -> Result, Error> { - if self.statement.output_format() == Format::Text { + if self.output_format == Format::Text { match self.col_buffer(idx) { Some(raw) => { FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 8743f00f0..246d36a57 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -6,7 +6,6 @@ use postgres_protocol::{ message::{backend::Field, frontend}, Oid, }; -use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -17,7 +16,6 @@ struct StatementInner { name: String, params: Vec, columns: Vec, - output_format: Format, } impl Drop for StatementInner { @@ -51,22 +49,6 @@ impl Statement { name, params, columns, - output_format: Format::Binary, - })) - } - - pub(crate) fn new_text( - inner: &Arc, - name: String, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, - params, - columns, - output_format: Format::Text, })) } @@ -83,11 +65,6 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } - - /// Returns output format for the statement. - pub fn output_format(&self) -> Format { - self.0.output_format - } } /// Information about a column of a query. diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 806196aa3..ca386974e 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -150,13 +150,14 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, query: S, params: I) -> Result + pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result where - S: AsRef + Sync + Send, + T: ?Sized + ToStatement, + S: AsRef, I: IntoIterator>, - I::IntoIter: ExactSizeIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, { - self.client.query_raw_txt(query, params).await + self.client.query_raw_txt(statement, params).await } /// Like `Client::execute`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 40ff0d7e5..565984271 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -312,7 +312,7 @@ async fn query_raw_txt_nulls() { async fn limit_max_backend_message_size() { let client = connect("user=postgres max_backend_message_size=10000").await; let small: Vec = client - .query_raw_txt("SELECT REPEAT('a', 20)", []) + .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -323,7 +323,7 @@ async fn limit_max_backend_message_size() { assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); let large: Result, Error> = client - .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0]) .await .unwrap() .try_collect() @@ -337,7 +337,7 @@ async fn command_tag() { let client = connect("user=postgres").await; let row_stream = client - .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0]) .await .unwrap(); @@ -351,12 +351,40 @@ async fn command_tag() { assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); } +#[tokio::test] +async fn ready_for_query() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'T')); + + let row_stream = client + .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'I')); +} + #[tokio::test] async fn column_extras() { let client = connect("user=postgres").await; let rows: Vec = client - .query_raw_txt("select relacl, relname from pg_class limit 1", []) + .query_raw_txt( + "select relacl, relname from pg_class limit 1", + [] as [Option<&str>; 0], + ) .await .unwrap() .try_collect() From a7731024191390654f89721df2396277b9f9cca1 Mon Sep 17 00:00:00 2001 From: Gus Caplan Date: Fri, 22 Sep 2023 16:41:07 -0700 Subject: [PATCH 06/10] support unnamed statements --- postgres-protocol/src/message/backend.rs | 23 +++++++- tokio-postgres/Cargo.toml | 2 +- tokio-postgres/src/bind.rs | 2 +- tokio-postgres/src/client.rs | 6 ++- tokio-postgres/src/connect.rs | 2 +- tokio-postgres/src/connect_raw.rs | 10 ++-- tokio-postgres/src/connection.rs | 3 +- tokio-postgres/src/copy_in.rs | 12 +++-- tokio-postgres/src/copy_out.rs | 10 ++-- tokio-postgres/src/error/mod.rs | 12 +++-- tokio-postgres/src/prepare.rs | 23 +++++--- tokio-postgres/src/query.rs | 13 +++-- tokio-postgres/src/simple_query.rs | 6 +-- tokio-postgres/src/statement.rs | 67 ++++++++++++++++++------ tokio-postgres/src/to_statement.rs | 2 +- 15 files changed, 140 insertions(+), 53 deletions(-) diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 1b5be1098..da267101c 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -72,6 +72,7 @@ impl Header { } /// An enum representing Postgres backend messages. +#[derive(Debug, PartialEq)] #[non_exhaustive] pub enum Message { AuthenticationCleartextPassword, @@ -333,6 +334,7 @@ impl Read for Buffer { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationMd5PasswordBody { salt: [u8; 4], } @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationGssContinueBody(Bytes); impl AuthenticationGssContinueBody { @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslBody(Bytes); impl AuthenticationSaslBody { @@ -362,6 +366,7 @@ impl AuthenticationSaslBody { } } +#[derive(Debug, PartialEq)] pub struct SaslMechanisms<'a>(&'a [u8]); impl<'a> FallibleIterator for SaslMechanisms<'a> { @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslContinueBody(Bytes); impl AuthenticationSaslContinueBody { @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslFinalBody(Bytes); impl AuthenticationSaslFinalBody { @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody { } } +#[derive(Debug, PartialEq)] pub struct BackendKeyDataBody { process_id: i32, secret_key: i32, @@ -422,6 +430,7 @@ impl BackendKeyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CommandCompleteBody { tag: Bytes, } @@ -433,6 +442,7 @@ impl CommandCompleteBody { } } +#[derive(Debug, PartialEq)] pub struct CopyDataBody { storage: Bytes, } @@ -449,6 +459,7 @@ impl CopyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CopyInResponseBody { format: u8, len: u16, @@ -470,6 +481,7 @@ impl CopyInResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ColumnFormats<'a> { buf: &'a [u8], remaining: u16, @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> { } } +#[derive(Debug, PartialEq)] pub struct CopyOutResponseBody { format: u8, len: u16, @@ -524,7 +537,7 @@ impl CopyOutResponseBody { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct DataRowBody { storage: Bytes, len: u16, @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { } } +#[derive(Debug, PartialEq)] pub struct ErrorResponseBody { storage: Bytes, } @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> { } } +#[derive(Debug, PartialEq)] pub struct NoticeResponseBody { storage: Bytes, } @@ -668,6 +683,7 @@ impl NoticeResponseBody { } } +#[derive(Debug, PartialEq)] pub struct NotificationResponseBody { process_id: i32, channel: Bytes, @@ -691,6 +707,7 @@ impl NotificationResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ParameterDescriptionBody { storage: Bytes, len: u16, @@ -706,6 +723,7 @@ impl ParameterDescriptionBody { } } +#[derive(Debug, PartialEq)] pub struct Parameters<'a> { buf: &'a [u8], remaining: u16, @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> { } } +#[derive(Debug, PartialEq)] pub struct ParameterStatusBody { name: Bytes, value: Bytes, @@ -756,6 +775,7 @@ impl ParameterStatusBody { } } +#[derive(Debug, PartialEq)] pub struct ReadyForQueryBody { status: u8, } @@ -767,6 +787,7 @@ impl ReadyForQueryBody { } } +#[derive(Debug, PartialEq)] pub struct RowDescriptionBody { storage: Bytes, len: u16, diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index ec5e3cbec..c11de2e2b 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" -whoami = "1.4.1" +whoami = "1.4" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs index 9c5c49218..dac1a3c06 100644 --- a/tokio-postgres/src/bind.rs +++ b/tokio-postgres/src/bind.rs @@ -31,7 +31,7 @@ where match responses.next().await? { Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(Portal::new(client, name, statement)) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e67553101..eb7b8bf0a 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -231,7 +231,11 @@ impl Client { query: &str, parameter_types: &[Type], ) -> Result { - prepare::prepare(&self.inner, query, parameter_types).await + prepare::prepare(&self.inner, query, parameter_types, false).await + } + + pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result { + prepare::prepare(&self.inner, query, &[], true).await } /// Executes a statement, returning a vector of the resulting rows. diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..e697e5bc6 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -195,7 +195,7 @@ where } } Some(_) => {} - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), } } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 7124557de..b468c5f32 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -195,14 +195,14 @@ where )) } Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), + Some(m) => Err(Error::unexpected_message(m)), None => Err(Error::closed()), } } @@ -296,7 +296,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -314,7 +314,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -353,7 +353,7 @@ where } Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } } diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..652038cc0 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -139,7 +139,8 @@ where Some(response) => response, None => match messages.next().map_err(Error::parse)? { Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), + None => return Err(Error::closed()), }, }; diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index 59e31fea6..f997e9433 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -114,7 +114,7 @@ where let rows = extract_row_affected(&body)?; return Poll::Ready(Ok(rows)); } - _ => return Poll::Ready(Err(Error::unexpected_message())), + m => return Poll::Ready(Err(Error::unexpected_message(m))), } } } @@ -206,13 +206,19 @@ where .map_err(|_| Error::closed())?; match responses.next().await? { + Message::ParseComplete => { + match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + } + } Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyInResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(CopyInSink { diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..4141bee92 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -26,13 +26,17 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyOutResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -56,7 +60,7 @@ impl Stream for CopyOutStream { match ready!(this.responses.poll_next(cx)?) { Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), Message::CopyDone => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + m => Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f1e2644c6..764f77f9c 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -1,7 +1,7 @@ //! Errors. use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; +use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody, Message}; use std::error::{self, Error as _Error}; use std::fmt; use std::io; @@ -339,7 +339,7 @@ pub enum ErrorPosition { #[derive(Debug, PartialEq)] enum Kind { Io, - UnexpectedMessage, + UnexpectedMessage(Message), Tls, ToSql(usize), FromSql(usize), @@ -379,7 +379,9 @@ impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0.kind { Kind::Io => fmt.write_str("error communicating with the server")?, - Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::UnexpectedMessage(msg) => { + write!(fmt, "unexpected message from server: {:?}", msg)? + } Kind::Tls => fmt.write_str("error performing TLS handshake")?, Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, @@ -445,8 +447,8 @@ impl Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { - Error::new(Kind::UnexpectedMessage, None) + pub(crate) fn unexpected_message(message: Message) -> Error { + Error::new(Kind::UnexpectedMessage(message), None) } #[allow(clippy::needless_pass_by_value)] diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 0abb8e453..9895aa0d4 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -62,25 +62,30 @@ pub async fn prepare( client: &Arc, query: &str, types: &[Type], + unnamed: bool, ) -> Result { - let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let name = if unnamed { + String::new() + } else { + format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)) + }; let buf = encode(client, &name, query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } let parameter_description = match responses.next().await? { Message::ParameterDescription(body) => body, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let row_description = match responses.next().await? { Message::RowDescription(body) => Some(body), Message::NoData => None, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let mut parameters = vec![]; @@ -100,7 +105,11 @@ pub async fn prepare( } } - Ok(Statement::new(client, name, parameters, columns)) + if unnamed { + Ok(Statement::unnamed(query.to_owned(), parameters, columns)) + } else { + Ok(Statement::named(client, name, parameters, columns)) + } } fn prepare_rec<'a>( @@ -108,7 +117,7 @@ fn prepare_rec<'a>( query: &'a str, types: &'a [Type], ) -> Pin> + 'a + Send>> { - Box::pin(prepare(client, query, types)) + Box::pin(prepare(client, query, types, false)) } fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { @@ -142,7 +151,7 @@ pub async fn get_type(client: &Arc, oid: Oid) -> Result row, - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), }; let name: String = row.try_get(0)?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 7cf9580e5..8b7e048e8 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -186,7 +186,7 @@ where } Message::EmptyQueryResponse => rows = 0, Message::ReadyForQuery(_) => return Ok(rows), - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -195,8 +195,12 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -209,6 +213,9 @@ where I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { + if let Some(query) = statement.query() { + frontend::parse("", query, [], buf).unwrap(); + } encode_bind(statement, params, "", buf)?; frontend::execute("", 0, buf).map_err(Error::encode)?; frontend::sync(buf); @@ -305,7 +312,7 @@ impl Stream for RowStream { *this.status = Some(status.status()); return Poll::Ready(None); } - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index bcc6d928b..9838b0809 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -58,7 +58,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro | Message::EmptyQueryResponse | Message::RowDescription(_) | Message::DataRow(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -107,12 +107,12 @@ impl Stream for SimpleQueryStream { Message::DataRow(body) => { let row = match &this.columns { Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, - None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + None => return Poll::Ready(Some(Err(Error::closed()))), }; return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); } Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 246d36a57..920bd74da 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -11,22 +11,31 @@ use std::{ sync::{Arc, Weak}, }; -struct StatementInner { - client: Weak, - name: String, - params: Vec, - columns: Vec, +enum StatementInner { + Unnamed { + query: String, + params: Vec, + columns: Vec, + }, + Named { + client: Weak, + name: String, + params: Vec, + columns: Vec, + }, } impl Drop for StatementInner { fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let buf = client.with_buf(|buf| { - frontend::close(b'S', &self.name, buf).unwrap(); - frontend::sync(buf); - buf.split().freeze() - }); - let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + if let StatementInner::Named { client, name, .. } = self { + if let Some(client) = client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'S', name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } } } } @@ -38,13 +47,13 @@ impl Drop for StatementInner { pub struct Statement(Arc); impl Statement { - pub(crate) fn new( + pub(crate) fn named( inner: &Arc, name: String, params: Vec, columns: Vec, ) -> Statement { - Statement(Arc::new(StatementInner { + Statement(Arc::new(StatementInner::Named { client: Arc::downgrade(inner), name, params, @@ -52,18 +61,42 @@ impl Statement { })) } + pub(crate) fn unnamed(query: String, params: Vec, columns: Vec) -> Self { + Statement(Arc::new(StatementInner::Unnamed { + query, + params, + columns, + })) + } + pub(crate) fn name(&self) -> &str { - &self.0.name + match &*self.0 { + StatementInner::Unnamed { .. } => "", + StatementInner::Named { name, .. } => name, + } + } + + pub(crate) fn query(&self) -> Option<&str> { + match &*self.0 { + StatementInner::Unnamed { query, .. } => Some(query), + StatementInner::Named { .. } => None, + } } /// Returns the expected types of the statement's parameters. pub fn params(&self) -> &[Type] { - &self.0.params + match &*self.0 { + StatementInner::Unnamed { params, .. } => params, + StatementInner::Named { params, .. } => params, + } } /// Returns information about the columns returned when the statement is queried. pub fn columns(&self) -> &[Column] { - &self.0.columns + match &*self.0 { + StatementInner::Unnamed { columns, .. } => columns, + StatementInner::Named { columns, .. } => columns, + } } } diff --git a/tokio-postgres/src/to_statement.rs b/tokio-postgres/src/to_statement.rs index 427f77dd7..ef1e65272 100644 --- a/tokio-postgres/src/to_statement.rs +++ b/tokio-postgres/src/to_statement.rs @@ -15,7 +15,7 @@ mod private { pub async fn into_statement(self, client: &Client) -> Result { match self { ToStatementType::Statement(s) => Ok(s.clone()), - ToStatementType::Query(s) => client.prepare(s).await, + ToStatementType::Query(s) => client.prepare_unnamed(s).await, } } } From 390d4cee702015455ac6e5cfc1d382b1134be8dd Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Mon, 6 Nov 2023 11:18:15 +0100 Subject: [PATCH 07/10] Enable transactional pool mode configuration --- postgres/src/config.rs | 14 ++++++++++++++ tokio-postgres/src/client.rs | 19 +++++++++++++++---- tokio-postgres/src/config.rs | 16 ++++++++++++++++ tokio-postgres/src/connect_raw.rs | 10 +++++++++- 4 files changed, 54 insertions(+), 5 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index a32ddc78e..ddac7111a 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -433,6 +433,20 @@ impl Config { self } + /// When enabled, the client skips all internal caching for statements, + /// allowing usage with a connection pool with transaction mode. + /// + /// Defaults to `false`. + pub fn transaction_pool_mode(&mut self, enable: bool) -> &mut Config { + self.config.transaction_pool_mode(enable); + self + } + + /// Gets the transaction pool mode status. + pub fn get_transaction_pool_mode(&self) -> bool { + self.config.get_transaction_pool_mode() + } + /// Opens a connection to a PostgreSQL database. pub fn connect(&self, tls: T) -> Result where diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index eb7b8bf0a..17e948564 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -83,6 +83,7 @@ struct CachedTypeInfo { pub struct InnerClient { sender: mpsc::UnboundedSender, + transaction_pool_mode: bool, cached_typeinfo: Mutex, /// A buffer to use when writing out postgres commands. @@ -108,7 +109,9 @@ impl InnerClient { } pub fn set_typeinfo(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + if !self.transaction_pool_mode { + self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + } } pub fn typeinfo_composite(&self) -> Option { @@ -116,7 +119,9 @@ impl InnerClient { } pub fn set_typeinfo_composite(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + if !self.transaction_pool_mode { + self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + } } pub fn typeinfo_enum(&self) -> Option { @@ -124,7 +129,9 @@ impl InnerClient { } pub fn set_typeinfo_enum(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + if !self.transaction_pool_mode { + self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + } } pub fn type_(&self, oid: Oid) -> Option { @@ -132,7 +139,9 @@ impl InnerClient { } pub fn set_type(&self, oid: Oid, type_: &Type) { - self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + if !self.transaction_pool_mode { + self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + } } pub fn clear_type_cache(&self) { @@ -190,12 +199,14 @@ impl Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, + transaction_pool_mode: bool, ) -> Client { Client { inner: Arc::new(InnerClient { sender, cached_typeinfo: Default::default(), buffer: Default::default(), + transaction_pool_mode, }), #[cfg(feature = "runtime")] socket_config: None, diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 2547469ec..59530994c 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -208,6 +208,7 @@ pub struct Config { pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, pub(crate) max_backend_message_size: Option, + pub(crate) transaction_pool_mode: bool, } impl Default for Config { @@ -242,6 +243,7 @@ impl Config { channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, max_backend_message_size: None, + transaction_pool_mode: false, } } @@ -509,6 +511,20 @@ impl Config { self.channel_binding } + /// When enabled, the client skips all internal caching for statements, + /// allowing usage with a connection pool with transaction mode. + /// + /// Defaults to `false`. + pub fn transaction_pool_mode(&mut self, enable: bool) -> &mut Config { + self.transaction_pool_mode = enable; + self + } + + /// Gets the transaction pool mode status. + pub fn get_transaction_pool_mode(&self) -> bool { + self.transaction_pool_mode + } + /// Sets the host load balancing behavior. /// /// Defaults to `disable`. diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index b468c5f32..2d6b55e4d 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -112,7 +112,15 @@ where let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); - let client = Client::new(sender, config.ssl_mode, process_id, secret_key); + + let client = Client::new( + sender, + config.ssl_mode, + process_id, + secret_key, + config.transaction_pool_mode, + ); + let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); Ok((client, connection)) From a9f7a86fe34c4222890228c0d049eefca8a87a4b Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Mon, 6 Nov 2023 15:54:09 +0100 Subject: [PATCH 08/10] Be really sure no statements are ever used in transaction_pool_mode --- tokio-postgres/src/client.rs | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 17e948564..0a3226fa6 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -105,7 +105,11 @@ impl InnerClient { } pub fn typeinfo(&self) -> Option { - self.cached_typeinfo.lock().typeinfo.clone() + if self.transaction_pool_mode { + None + } else { + self.cached_typeinfo.lock().typeinfo.clone() + } } pub fn set_typeinfo(&self, statement: &Statement) { @@ -115,7 +119,11 @@ impl InnerClient { } pub fn typeinfo_composite(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_composite.clone() + if self.transaction_pool_mode { + None + } else { + self.cached_typeinfo.lock().typeinfo_composite.clone() + } } pub fn set_typeinfo_composite(&self, statement: &Statement) { @@ -125,7 +133,11 @@ impl InnerClient { } pub fn typeinfo_enum(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_enum.clone() + if self.transaction_pool_mode { + None + } else { + self.cached_typeinfo.lock().typeinfo_enum.clone() + } } pub fn set_typeinfo_enum(&self, statement: &Statement) { @@ -135,7 +147,11 @@ impl InnerClient { } pub fn type_(&self, oid: Oid) -> Option { - self.cached_typeinfo.lock().types.get(&oid).cloned() + if self.transaction_pool_mode { + None + } else { + self.cached_typeinfo.lock().types.get(&oid).cloned() + } } pub fn set_type(&self, oid: Oid, type_: &Type) { @@ -145,7 +161,9 @@ impl InnerClient { } pub fn clear_type_cache(&self) { - self.cached_typeinfo.lock().types.clear(); + if !self.transaction_pool_mode { + self.cached_typeinfo.lock().types.clear(); + } } /// Call the given function with a buffer to be used when writing out From 2120af910c20072c02cf50c1f38ecb18d0714658 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Mon, 6 Nov 2023 17:32:47 +0100 Subject: [PATCH 09/10] fix(pool): Skip all type info queries, we run in text mode --- postgres/src/config.rs | 14 --- tokio-postgres/src/client.rs | 83 +------------ tokio-postgres/src/config.rs | 16 --- tokio-postgres/src/connect_raw.rs | 9 +- tokio-postgres/src/prepare.rs | 181 +---------------------------- tokio-postgres/src/query.rs | 1 + tokio-postgres/src/statement.rs | 3 +- tokio-postgres/src/to_statement.rs | 1 + 8 files changed, 11 insertions(+), 297 deletions(-) diff --git a/postgres/src/config.rs b/postgres/src/config.rs index ddac7111a..a32ddc78e 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -433,20 +433,6 @@ impl Config { self } - /// When enabled, the client skips all internal caching for statements, - /// allowing usage with a connection pool with transaction mode. - /// - /// Defaults to `false`. - pub fn transaction_pool_mode(&mut self, enable: bool) -> &mut Config { - self.config.transaction_pool_mode(enable); - self - } - - /// Gets the transaction pool mode status. - pub fn get_transaction_pool_mode(&self) -> bool { - self.config.get_transaction_pool_mode() - } - /// Opens a connection to a PostgreSQL database. pub fn connect(&self, tls: T) -> Result where diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 0a3226fa6..70b409633 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -61,30 +61,9 @@ impl Responses { } } -/// A cache of type info and prepared statements for fetching type info -/// (corresponding to the queries in the [prepare](prepare) module). -#[derive(Default)] -struct CachedTypeInfo { - /// A statement for basic information for a type from its - /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its - /// fallback). - typeinfo: Option, - /// A statement for getting information for a composite type from its OID. - /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY). - typeinfo_composite: Option, - /// A statement for getting information for a composite type from its OID. - /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or - /// its fallback). - typeinfo_enum: Option, - - /// Cache of types already looked up. - types: HashMap, -} - pub struct InnerClient { sender: mpsc::UnboundedSender, - transaction_pool_mode: bool, - cached_typeinfo: Mutex, + cached_typeinfo: Mutex>, /// A buffer to use when writing out postgres commands. buffer: Mutex, @@ -104,66 +83,12 @@ impl InnerClient { }) } - pub fn typeinfo(&self) -> Option { - if self.transaction_pool_mode { - None - } else { - self.cached_typeinfo.lock().typeinfo.clone() - } - } - - pub fn set_typeinfo(&self, statement: &Statement) { - if !self.transaction_pool_mode { - self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); - } - } - - pub fn typeinfo_composite(&self) -> Option { - if self.transaction_pool_mode { - None - } else { - self.cached_typeinfo.lock().typeinfo_composite.clone() - } - } - - pub fn set_typeinfo_composite(&self, statement: &Statement) { - if !self.transaction_pool_mode { - self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); - } - } - - pub fn typeinfo_enum(&self) -> Option { - if self.transaction_pool_mode { - None - } else { - self.cached_typeinfo.lock().typeinfo_enum.clone() - } - } - - pub fn set_typeinfo_enum(&self, statement: &Statement) { - if !self.transaction_pool_mode { - self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); - } - } - pub fn type_(&self, oid: Oid) -> Option { - if self.transaction_pool_mode { - None - } else { - self.cached_typeinfo.lock().types.get(&oid).cloned() - } - } - - pub fn set_type(&self, oid: Oid, type_: &Type) { - if !self.transaction_pool_mode { - self.cached_typeinfo.lock().types.insert(oid, type_.clone()); - } + self.cached_typeinfo.lock().get(&oid).cloned() } pub fn clear_type_cache(&self) { - if !self.transaction_pool_mode { - self.cached_typeinfo.lock().types.clear(); - } + self.cached_typeinfo.lock().clear(); } /// Call the given function with a buffer to be used when writing out @@ -217,14 +142,12 @@ impl Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, - transaction_pool_mode: bool, ) -> Client { Client { inner: Arc::new(InnerClient { sender, cached_typeinfo: Default::default(), buffer: Default::default(), - transaction_pool_mode, }), #[cfg(feature = "runtime")] socket_config: None, diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 59530994c..2547469ec 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -208,7 +208,6 @@ pub struct Config { pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, pub(crate) max_backend_message_size: Option, - pub(crate) transaction_pool_mode: bool, } impl Default for Config { @@ -243,7 +242,6 @@ impl Config { channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, max_backend_message_size: None, - transaction_pool_mode: false, } } @@ -511,20 +509,6 @@ impl Config { self.channel_binding } - /// When enabled, the client skips all internal caching for statements, - /// allowing usage with a connection pool with transaction mode. - /// - /// Defaults to `false`. - pub fn transaction_pool_mode(&mut self, enable: bool) -> &mut Config { - self.transaction_pool_mode = enable; - self - } - - /// Gets the transaction pool mode status. - pub fn get_transaction_pool_mode(&self) -> bool { - self.transaction_pool_mode - } - /// Sets the host load balancing behavior. /// /// Defaults to `disable`. diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 2d6b55e4d..964612b4e 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -113,14 +113,7 @@ where let (sender, receiver) = mpsc::unbounded(); - let client = Client::new( - sender, - config.ssl_mode, - process_id, - secret_key, - config.transaction_pool_mode, - ); - + let client = Client::new(sender, config.ssl_mode, process_id, secret_key); let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); Ok((client, connection)) diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 9895aa0d4..c17ae0e74 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -1,61 +1,16 @@ use crate::client::InnerClient; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::error::SqlState; -use crate::types::{Field, Kind, Oid, Type}; -use crate::{query, slice_iter}; +use crate::types::{Oid, Type}; use crate::{Column, Error, Statement}; use bytes::Bytes; use fallible_iterator::FallibleIterator; -use futures_util::{pin_mut, TryStreamExt}; use log::debug; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; -use std::future::Future; -use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -const TYPEINFO_QUERY: &str = "\ -SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid -FROM pg_catalog.pg_type t -LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid -INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid -WHERE t.oid = $1 -"; - -// Range types weren't added until Postgres 9.2, so pg_range may not exist -const TYPEINFO_FALLBACK_QUERY: &str = "\ -SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid -FROM pg_catalog.pg_type t -INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid -WHERE t.oid = $1 -"; - -const TYPEINFO_ENUM_QUERY: &str = "\ -SELECT enumlabel -FROM pg_catalog.pg_enum -WHERE enumtypid = $1 -ORDER BY enumsortorder -"; - -// Postgres 9.0 didn't have enumsortorder -const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\ -SELECT enumlabel -FROM pg_catalog.pg_enum -WHERE enumtypid = $1 -ORDER BY oid -"; - -const TYPEINFO_COMPOSITE_QUERY: &str = "\ -SELECT attname, atttypid -FROM pg_catalog.pg_attribute -WHERE attrelid = $1 -AND NOT attisdropped -AND attnum > 0 -ORDER BY attnum -"; - static NEXT_ID: AtomicUsize = AtomicUsize::new(0); pub async fn prepare( @@ -69,6 +24,7 @@ pub async fn prepare( } else { format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)) }; + let buf = encode(client, &name, query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; @@ -112,14 +68,6 @@ pub async fn prepare( } } -fn prepare_rec<'a>( - client: &'a Arc, - query: &'a str, - types: &'a [Type], -) -> Pin> + 'a + Send>> { - Box::pin(prepare(client, query, types, false)) -} - fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { if types.is_empty() { debug!("preparing query {}: {}", name, query); @@ -144,128 +92,5 @@ pub async fn get_type(client: &Arc, oid: Oid) -> Result row, - None => return Err(Error::closed()), - }; - - let name: String = row.try_get(0)?; - let type_: i8 = row.try_get(1)?; - let elem_oid: Oid = row.try_get(2)?; - let rngsubtype: Option = row.try_get(3)?; - let basetype: Oid = row.try_get(4)?; - let schema: String = row.try_get(5)?; - let relid: Oid = row.try_get(6)?; - - let kind = if type_ == b'e' as i8 { - let variants = get_enum_variants(client, oid).await?; - Kind::Enum(variants) - } else if type_ == b'p' as i8 { - Kind::Pseudo - } else if basetype != 0 { - let type_ = get_type_rec(client, basetype).await?; - Kind::Domain(type_) - } else if elem_oid != 0 { - let type_ = get_type_rec(client, elem_oid).await?; - Kind::Array(type_) - } else if relid != 0 { - let fields = get_composite_fields(client, relid).await?; - Kind::Composite(fields) - } else if let Some(rngsubtype) = rngsubtype { - let type_ = get_type_rec(client, rngsubtype).await?; - Kind::Range(type_) - } else { - Kind::Simple - }; - - let type_ = Type::new(name, oid, kind, schema); - client.set_type(oid, &type_); - - Ok(type_) -} - -fn get_type_rec<'a>( - client: &'a Arc, - oid: Oid, -) -> Pin> + Send + 'a>> { - Box::pin(get_type(client, oid)) -} - -async fn typeinfo_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo() { - return Ok(stmt); - } - - let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { - Ok(stmt) => stmt, - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { - prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? - } - Err(e) => return Err(e), - }; - - client.set_typeinfo(&stmt); - Ok(stmt) -} - -async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { - let stmt = typeinfo_enum_statement(client).await?; - - query::query(client, stmt, slice_iter(&[&oid])) - .await? - .and_then(|row| async move { row.try_get(0) }) - .try_collect() - .await -} - -async fn typeinfo_enum_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo_enum() { - return Ok(stmt); - } - - let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { - Ok(stmt) => stmt, - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { - prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? - } - Err(e) => return Err(e), - }; - - client.set_typeinfo_enum(&stmt); - Ok(stmt) -} - -async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { - let stmt = typeinfo_composite_statement(client).await?; - - let rows = query::query(client, stmt, slice_iter(&[&oid])) - .await? - .try_collect::>() - .await?; - - let mut fields = vec![]; - for row in rows { - let name = row.try_get(0)?; - let oid = row.try_get(1)?; - let type_ = get_type_rec(client, oid).await?; - fields.push(Field::new(name, type_)); - } - - Ok(fields) -} - -async fn typeinfo_composite_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo_composite() { - return Ok(stmt); - } - - let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; - - client.set_typeinfo_composite(&stmt); - Ok(stmt) + Ok(Type::TEXT) } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 8b7e048e8..b1ba91821 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -106,6 +106,7 @@ where // now read the responses let responses = start(client, buf).await?; + Ok(RowStream { statement, responses, diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 920bd74da..8b629732c 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -11,6 +11,7 @@ use std::{ sync::{Arc, Weak}, }; +#[derive(Debug)] enum StatementInner { Unnamed { query: String, @@ -43,7 +44,7 @@ impl Drop for StatementInner { /// A prepared statement. /// /// Prepared statements can only be used with the connection that created them. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Statement(Arc); impl Statement { diff --git a/tokio-postgres/src/to_statement.rs b/tokio-postgres/src/to_statement.rs index ef1e65272..cbf24ee0f 100644 --- a/tokio-postgres/src/to_statement.rs +++ b/tokio-postgres/src/to_statement.rs @@ -6,6 +6,7 @@ mod private { pub trait Sealed {} + #[derive(Debug)] pub enum ToStatementType<'a> { Statement(&'a Statement), Query(&'a str), From e30b9c2df70a929dd2f3f1c63b86c484d26a9e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Houl=C3=A9?= Date: Tue, 7 Nov 2023 12:53:39 +0100 Subject: [PATCH 10/10] tokio-postgres: prepare and execute unnamed statements in one roundtrip --- tokio-postgres/src/client.rs | 10 +-- tokio-postgres/src/generic_client.rs | 11 +-- tokio-postgres/src/prepare.rs | 18 ++-- tokio-postgres/src/query.rs | 119 ++++++++++++++++++++++----- tokio-postgres/src/transaction.rs | 5 +- 5 files changed, 117 insertions(+), 46 deletions(-) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 70b409633..5e4f95d2a 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -4,7 +4,7 @@ use crate::connection::{Request, RequestMessages}; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; -use crate::query::RowStream; +use crate::query::{RowStream, }; use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -326,19 +326,17 @@ impl Client { /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt<'a, T, S, I>( + pub async fn query_raw_txt<'a, S, I>( &self, - statement: &T, + query: &str, params: I, ) -> Result where - T: ?Sized + ToStatement, S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - let statement = statement.__convert().into_statement(self).await?; - query::query_txt(&self.inner, statement, params).await + query::query_txt(&self.inner, query, params).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index a4ee4808b..559787e2d 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -57,13 +57,12 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator; /// Like `Client::query_raw_txt`. - async fn query_raw_txt<'a, T, S, I>( + async fn query_raw_txt<'a, S, I>( &self, - statement: &T, + statement: &str, params: I, ) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; @@ -148,9 +147,8 @@ impl GenericClient for Client { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + async fn query_raw_txt<'a, S, I>(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, @@ -244,9 +242,8 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } - async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + async fn query_raw_txt<'a, S, I>(&self, statement: &str, params: I) -> Result where - T: ?Sized + ToStatement + Sync + Send, S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send, diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index c17ae0e74..b49db0343 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -47,7 +47,7 @@ pub async fn prepare( let mut parameters = vec![]; let mut it = parameter_description.parameters(); while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, oid).await?; + let type_ = get_type(client, oid); parameters.push(type_); } @@ -55,7 +55,7 @@ pub async fn prepare( if let Some(row_description) = row_description { let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, field.type_oid()).await?; + let type_ = get_type(client, field.type_oid()); let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } @@ -68,7 +68,7 @@ pub async fn prepare( } } -fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { +pub(crate) fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { if types.is_empty() { debug!("preparing query {}: {}", name, query); } else { @@ -83,14 +83,14 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -pub async fn get_type(client: &Arc, oid: Oid) -> Result { +pub fn get_type(client: &InnerClient, oid: Oid) -> Type { if let Some(type_) = Type::from_oid(oid) { - return Ok(type_); + return type_; } - if let Some(type_) = client.type_(oid) { - return Ok(type_); - } + // if let Some(type_) = client.type_(oid) { + // return Ok(type_); + // } - Ok(Type::TEXT) + Type::TEXT } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index b1ba91821..3f26c33cc 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -2,14 +2,15 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Error, Portal, Row, Statement, Column}; use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; -use postgres_protocol::message::backend::{CommandCompleteBody, Message}; +use postgres_protocol::message::backend::{CommandCompleteBody, Message, ParameterDescriptionBody, RowDescriptionBody}; use postgres_protocol::message::frontend; -use postgres_types::Format; +use postgres_types::{Format, Type}; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; @@ -50,7 +51,7 @@ where } else { encode(client, &statement, params)? }; - let responses = start(client, buf).await?; + let (statement, responses) = start(client, buf).await?; Ok(RowStream { statement, responses, @@ -58,13 +59,14 @@ where command_tag: None, status: None, output_format: Format::Binary, + parameter_description: None, _p: PhantomPinned, }) } pub async fn query_txt( client: &Arc, - statement: Statement, + query: &str, params: I, ) -> Result where @@ -72,13 +74,19 @@ where I: IntoIterator>, I::IntoIter: ExactSizeIterator, { + dbg!("here"); let params = params.into_iter(); let buf = client.with_buf(|buf| { + // prepare + frontend::parse("", query, std::iter::empty(), buf).map_err(Error::encode)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + frontend::flush(buf); + // Bind, pass params as text, retrieve as binary match frontend::bind( "", // empty string selects the unnamed portal - statement.name(), // named prepared statement + "", // unnamed prepared statement std::iter::empty(), // all parameters use the default format (text) params, |param, buf| match param { @@ -104,10 +112,13 @@ where Ok(buf.split().freeze()) })?; + dbg!("here"); // now read the responses - let responses = start(client, buf).await?; + let (statement, responses) = start(client, buf).await?; + dbg!("here"); Ok(RowStream { + parameter_description: None, statement, responses, command_tag: None, @@ -132,7 +143,8 @@ pub async fn query_portal( let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; Ok(RowStream { - statement: portal.statement().clone(), + parameter_description: None, + statement: Some(portal.statement().clone()), responses, rows_affected: None, command_tag: None, @@ -176,7 +188,7 @@ where } else { encode(client, &statement, params)? }; - let mut responses = start(client, buf).await?; + let (_statement, mut responses) = start(client, buf).await?; let mut rows = 0; loop { @@ -192,19 +204,49 @@ where } } -async fn start(client: &InnerClient, buf: Bytes) -> Result { +async fn start(client: &InnerClient, buf: Bytes) -> Result<(Option, Responses), Error> { + let mut parameter_description: Option = None; + let mut statement = None; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let make_statement = |parameter_description: ParameterDescriptionBody, row_description: Option| { + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse).unwrap() { + let type_ = crate::prepare::get_type(client, oid); + parameters.push(type_); + } - match responses.next().await? { - Message::ParseComplete => match responses.next().await? { - Message::BindComplete => {} - m => return Err(Error::unexpected_message(m)), - }, - Message::BindComplete => {} - m => return Err(Error::unexpected_message(m)), - } + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse).unwrap() { + let type_ = crate::prepare::get_type(client, field.type_oid()); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Statement::unnamed("Dose this matter?".to_owned(), parameters, columns) + + }; + + loop { + match responses.next().await? { + Message::ParseComplete => {}, + Message::BindComplete => {return Ok((statement, responses))} + Message::ParameterDescription(body) => { + parameter_description = Some(body); // to love me + } + Message::NoData => { + statement = Some(make_statement(parameter_description.take().unwrap(), None)); + } + Message::RowDescription(body) => { + statement = Some(make_statement(parameter_description.take().unwrap(), Some(body))); + } + m => return Err(Error::unexpected_message(m)), + } + } - Ok(responses) } pub fn encode(client: &InnerClient, statement: &Statement, params: I) -> Result @@ -276,12 +318,14 @@ where pin_project! { /// A stream of table rows. pub struct RowStream { - statement: Statement, + statement: Option, responses: Responses, rows_affected: Option, command_tag: Option, output_format: Format, status: Option, + parameter_description: Option, + #[pin] _p: PhantomPinned, } @@ -291,12 +335,44 @@ impl Stream for RowStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // dbg!("here"); let this = self.project(); + // let make_statement = |parameter_description: ParameterDescriptionBody, row_description: Option| { + // let mut parameters = vec![]; + // let mut it = parameter_description.parameters(); + // while let Some(oid) = it.next().map_err(Error::parse).unwrap() { + // let type_ = Type::TEXT; + // parameters.push(type_); + // } + + // let mut columns = vec![]; + // if let Some(row_description) = row_description { + // let mut it = row_description.fields(); + // while let Some(field) = it.next().map_err(Error::parse).unwrap() { + // let type_ = crate::prepare::get_type(client, field.type_oid()); + // let column = Column::new(field.name().to_string(), type_, field); + // columns.push(column); + // } + // } + + // Statement::unnamed("Dose this matter?".to_owned(), parameters, columns) + + // }; loop { match ready!(this.responses.poll_next(cx)?) { + // Message::ParseComplete => {} // nice + // Message::ParameterDescription(body) => { + // *this.parameter_description = Some(body); // to love me + // } + // Message::NoData => { + // *this.statement = Some(make_statement(this.parameter_description.take().unwrap(), None)); + // } + // Message::RowDescription(body) => { + // *this.statement = Some(make_statement(this.parameter_description.take().unwrap(), Some(body))); + // } Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new( - this.statement.clone(), + this.statement.as_ref().unwrap().clone(), body, *this.output_format, )?))) @@ -341,3 +417,4 @@ impl RowStream { self.status } } + diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index ca386974e..7221d9cca 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -150,14 +150,13 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result + pub async fn query_raw_txt< S, I>(&self, query: &str, params: I) -> Result where - T: ?Sized + ToStatement, S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - self.client.query_raw_txt(statement, params).await + self.client.query_raw_txt(query, params).await } /// Like `Client::execute`.