Skip to content

Commit

Permalink
Work with pools that don't support prepared statements
Browse files Browse the repository at this point in the history
Introduce a new `query_with_param_types` method that allows to specify Postgres type parameters. This obviated the need to use prepared statementsjust to obtain parameter types for a query. It then combines parse, bind, and execute in a single packet.

Related: sfackler#1017, sfackler#1067
  • Loading branch information
ramnivas committed Jun 4, 2024
1 parent 98f5a11 commit f397668
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 6 deletions.
82 changes: 82 additions & 0 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,88 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Like `query`, but requires the types of query parameters to be explicitly specified.
///
/// Compared to `query`, this method allows performing queries without three round trips (for prepare, execute, and close). Thus,
/// this is suitable in environments where prepared statements aren't supported (such as Cloudflare Workers with Hyperdrive).
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let rows = client.query_with_param_types(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// for row in rows {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_raw_with_param_types(statement, params)
.await?
.try_collect()
.await
}

/// The maximally flexible version of [`query_with_param_types`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The parameters must specify value along with their Postgres type. This allows performing
/// queries without three round trips (for prepare, execute, and close).
///
/// [`query_with_param_types`]: #method.query_with_param_types
///
/// # Examples
///
/// ```no_run
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
/// use tokio_postgres::types::ToSql;
/// use tokio_postgres::types::Type;
/// use futures_util::{pin_mut, TryStreamExt};
///
/// let mut it = client.query_raw_with_param_types(
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
/// ).await?;
///
/// pin_mut!(it);
/// while let Some(row) = it.try_next().await? {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// # Ok(())
/// # }
/// ```
pub async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
fn slice_iter<'a>(
s: &'a [(&'a (dyn ToSql + Sync), Type)],
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
s.iter()
.map(|(param, param_type)| (*param as _, param_type.clone()))
}

query::query_with_param_types(&self.inner, statement, slice_iter(params)).await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
46 changes: 46 additions & 0 deletions tokio-postgres/src/generic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ pub trait GenericClient: private::Sealed {
I: IntoIterator<Item = P> + Sync + Send,
I::IntoIter: ExactSizeIterator;

/// Like `Client::query_with_param_types`
async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error>;

/// Like `Client::query_raw_with_param_types`.
async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error>;

/// Like `Client::prepare`.
async fn prepare(&self, query: &str) -> Result<Statement, Error>;

Expand Down Expand Up @@ -136,6 +150,22 @@ impl GenericClient for Client {
self.query_raw(statement, params).await
}

async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_with_param_types(statement, params).await
}

async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
self.query_raw_with_param_types(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down Expand Up @@ -222,6 +252,22 @@ impl GenericClient for Transaction<'_> {
self.query_raw(statement, params).await
}

async fn query_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<Vec<Row>, Error> {
self.query_with_param_types(statement, params).await
}

async fn query_raw_with_param_types(
&self,
statement: &str,
params: &[(&(dyn ToSql + Sync), Type)],
) -> Result<RowStream, Error> {
self.query_raw_with_param_types(statement, params).await
}

async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare(query).await
}
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
})
}

async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
Expand Down
146 changes: 141 additions & 5 deletions tokio-postgres/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::prepare::get_type;
use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement};
use crate::{Column, Error, Portal, Row, Statement};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::backend::{CommandCompleteBody, Message, RowDescriptionBody};
use postgres_protocol::message::frontend;
use postgres_types::Type;
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
Expand Down Expand Up @@ -50,13 +54,125 @@ where
};
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
statement: statement,
responses,
rows_affected: None,
_p: PhantomPinned,
})
}

enum QueryProcessingState {
Empty,
ParseCompleted,
BindCompleted,
ParameterDescribed,
Final(Vec<Column>),
}

/// State machine for processing messages for `query_with_param_types`.
impl QueryProcessingState {
pub async fn process_message(
self,
client: &Arc<InnerClient>,
message: Message,
) -> Result<Self, Error> {
match (self, message) {
(QueryProcessingState::Empty, Message::ParseComplete) => {
Ok(QueryProcessingState::ParseCompleted)
}
(QueryProcessingState::ParseCompleted, Message::BindComplete) => {
Ok(QueryProcessingState::BindCompleted)
}
(QueryProcessingState::BindCompleted, Message::ParameterDescription(_)) => {
Ok(QueryProcessingState::ParameterDescribed)
}
(
QueryProcessingState::ParameterDescribed,
Message::RowDescription(row_description),
) => Self::form_final(client, Some(row_description)).await,
(QueryProcessingState::ParameterDescribed, Message::NoData) => {
Self::form_final(client, None).await
}
(_, Message::ErrorResponse(body)) => Err(Error::db(body)),
_ => Err(Error::unexpected_message()),
}
}

async fn form_final(
client: &Arc<InnerClient>,
row_description: Option<RowDescriptionBody>,
) -> Result<Self, Error> {
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let column = Column {
name: field.name().to_string(),
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
column_id: Some(field.column_id()).filter(|n| *n != 0),
r#type: type_,
};
columns.push(column);
}
}

Ok(Self::Final(columns))
}
}

pub async fn query_with_param_types<'a, P, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
P: BorrowToSql,
I: IntoIterator<Item = (P, Type)>,
I::IntoIter: ExactSizeIterator,
{
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();

let params = params.into_iter();

let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();

let params = params.into_iter();

let buf = client.with_buf(|buf| {
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;

encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;

frontend::describe(b'S', "", buf).map_err(Error::encode)?;

frontend::execute("", 0, buf).map_err(Error::encode)?;

frontend::sync(buf);

Ok(buf.split().freeze())
})?;

let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

let mut state = QueryProcessingState::Empty;

loop {
let message = responses.next().await?;

state = state.process_message(client, message).await?;

if let QueryProcessingState::Final(columns) = state {
return Ok(RowStream {
statement: Statement::unnamed(vec![], columns),
responses,
rows_affected: None,
_p: PhantomPinned,
});
}
}
}

pub async fn query_portal(
client: &InnerClient,
portal: &Portal,
Expand Down Expand Up @@ -164,7 +280,27 @@ where
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
encode_bind_with_statement_name_and_param_types(
statement.name(),
statement.params(),
params,
portal,
buf,
)
}

fn encode_bind_with_statement_name_and_param_types<P, I>(
statement_name: &str,
param_types: &[Type],
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();

if param_types.len() != params.len() {
Expand All @@ -181,7 +317,7 @@ where
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
statement_name,
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
Expand Down
13 changes: 13 additions & 0 deletions tokio-postgres/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ struct StatementInner {

impl Drop for StatementInner {
fn drop(&mut self) {
if self.name.is_empty() {
// Unnamed statements don't need to be closed
return;
}
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
Expand Down Expand Up @@ -46,6 +50,15 @@ impl Statement {
}))
}

pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: String::new(),
params,
columns,
}))
}

pub(crate) fn name(&self) -> &str {
&self.0.name
}
Expand Down
Loading

0 comments on commit f397668

Please sign in to comment.