Skip to content

Commit 257bcfd

Browse files
authored
Merge pull request #1147 from exograph/exograph
Work with pools that don't support prepared statements
2 parents 647a925 + 0fa3247 commit 257bcfd

File tree

7 files changed

+292
-4
lines changed

7 files changed

+292
-4
lines changed

tokio-postgres/src/client.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,54 @@ impl Client {
364364
query::query(&self.inner, statement, params).await
365365
}
366366

367+
/// Like `query`, but requires the types of query parameters to be explicitly specified.
368+
///
369+
/// Compared to `query`, this method allows performing queries without three round trips (for
370+
/// prepare, execute, and close) by requiring the caller to specify parameter values along with
371+
/// their Postgres type. Thus, this is suitable in environments where prepared statements aren't
372+
/// supported (such as Cloudflare Workers with Hyperdrive).
373+
///
374+
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the
375+
/// parameter of the list provided, 1-indexed.
376+
///
377+
/// # Examples
378+
///
379+
/// ```no_run
380+
/// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
381+
/// use tokio_postgres::types::ToSql;
382+
/// use tokio_postgres::types::Type;
383+
/// use futures_util::{pin_mut, TryStreamExt};
384+
///
385+
/// let rows = client.query_typed(
386+
/// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2",
387+
/// &[(&"first param", Type::TEXT), (&2i32, Type::INT4)],
388+
/// ).await?;
389+
///
390+
/// for row in rows {
391+
/// let foo: i32 = row.get("foo");
392+
/// println!("foo: {}", foo);
393+
/// }
394+
/// # Ok(())
395+
/// # }
396+
/// ```
397+
pub async fn query_typed(
398+
&self,
399+
statement: &str,
400+
params: &[(&(dyn ToSql + Sync), Type)],
401+
) -> Result<Vec<Row>, Error> {
402+
fn slice_iter<'a>(
403+
s: &'a [(&'a (dyn ToSql + Sync), Type)],
404+
) -> impl ExactSizeIterator<Item = (&'a dyn ToSql, Type)> + 'a {
405+
s.iter()
406+
.map(|(param, param_type)| (*param as _, param_type.clone()))
407+
}
408+
409+
query::query_typed(&self.inner, statement, slice_iter(params))
410+
.await?
411+
.try_collect()
412+
.await
413+
}
414+
367415
/// Executes a statement, returning the number of rows modified.
368416
///
369417
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

tokio-postgres/src/generic_client.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ pub trait GenericClient: private::Sealed {
5656
I: IntoIterator<Item = P> + Sync + Send,
5757
I::IntoIter: ExactSizeIterator;
5858

59+
/// Like [`Client::query_typed`]
60+
async fn query_typed(
61+
&self,
62+
statement: &str,
63+
params: &[(&(dyn ToSql + Sync), Type)],
64+
) -> Result<Vec<Row>, Error>;
65+
5966
/// Like [`Client::prepare`].
6067
async fn prepare(&self, query: &str) -> Result<Statement, Error>;
6168

@@ -139,6 +146,14 @@ impl GenericClient for Client {
139146
self.query_raw(statement, params).await
140147
}
141148

149+
async fn query_typed(
150+
&self,
151+
statement: &str,
152+
params: &[(&(dyn ToSql + Sync), Type)],
153+
) -> Result<Vec<Row>, Error> {
154+
self.query_typed(statement, params).await
155+
}
156+
142157
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
143158
self.prepare(query).await
144159
}
@@ -229,6 +244,14 @@ impl GenericClient for Transaction<'_> {
229244
self.query_raw(statement, params).await
230245
}
231246

247+
async fn query_typed(
248+
&self,
249+
statement: &str,
250+
params: &[(&(dyn ToSql + Sync), Type)],
251+
) -> Result<Vec<Row>, Error> {
252+
self.query_typed(statement, params).await
253+
}
254+
232255
async fn prepare(&self, query: &str) -> Result<Statement, Error> {
233256
self.prepare(query).await
234257
}

tokio-postgres/src/prepare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
131131
})
132132
}
133133

134-
async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
134+
pub(crate) async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
135135
if let Some(type_) = Type::from_oid(oid) {
136136
return Ok(type_);
137137
}

tokio-postgres/src/query.rs

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
use crate::client::{InnerClient, Responses};
22
use crate::codec::FrontendMessage;
33
use crate::connection::RequestMessages;
4+
use crate::prepare::get_type;
45
use crate::types::{BorrowToSql, IsNull};
5-
use crate::{Error, Portal, Row, Statement};
6+
use crate::{Column, Error, Portal, Row, Statement};
67
use bytes::{Bytes, BytesMut};
8+
use fallible_iterator::FallibleIterator;
79
use futures_util::{ready, Stream};
810
use log::{debug, log_enabled, Level};
911
use pin_project_lite::pin_project;
1012
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
1113
use postgres_protocol::message::frontend;
14+
use postgres_types::Type;
1215
use std::fmt;
1316
use std::marker::PhantomPinned;
1417
use std::pin::Pin;
18+
use std::sync::Arc;
1519
use std::task::{Context, Poll};
1620

1721
struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
@@ -57,6 +61,71 @@ where
5761
})
5862
}
5963

64+
pub async fn query_typed<'a, P, I>(
65+
client: &Arc<InnerClient>,
66+
query: &str,
67+
params: I,
68+
) -> Result<RowStream, Error>
69+
where
70+
P: BorrowToSql,
71+
I: IntoIterator<Item = (P, Type)>,
72+
I::IntoIter: ExactSizeIterator,
73+
{
74+
let (params, param_types): (Vec<_>, Vec<_>) = params.into_iter().unzip();
75+
76+
let params = params.into_iter();
77+
78+
let param_oids = param_types.iter().map(|t| t.oid()).collect::<Vec<_>>();
79+
80+
let params = params.into_iter();
81+
82+
let buf = client.with_buf(|buf| {
83+
frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
84+
85+
encode_bind_with_statement_name_and_param_types("", &param_types, params, "", buf)?;
86+
87+
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
88+
89+
frontend::execute("", 0, buf).map_err(Error::encode)?;
90+
91+
frontend::sync(buf);
92+
93+
Ok(buf.split().freeze())
94+
})?;
95+
96+
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
97+
98+
loop {
99+
match responses.next().await? {
100+
Message::ParseComplete
101+
| Message::BindComplete
102+
| Message::ParameterDescription(_)
103+
| Message::NoData => {}
104+
Message::RowDescription(row_description) => {
105+
let mut columns: Vec<Column> = vec![];
106+
let mut it = row_description.fields();
107+
while let Some(field) = it.next().map_err(Error::parse)? {
108+
let type_ = get_type(client, field.type_oid()).await?;
109+
let column = Column {
110+
name: field.name().to_string(),
111+
table_oid: Some(field.table_oid()).filter(|n| *n != 0),
112+
column_id: Some(field.column_id()).filter(|n| *n != 0),
113+
r#type: type_,
114+
};
115+
columns.push(column);
116+
}
117+
return Ok(RowStream {
118+
statement: Statement::unnamed(vec![], columns),
119+
responses,
120+
rows_affected: None,
121+
_p: PhantomPinned,
122+
});
123+
}
124+
_ => return Err(Error::unexpected_message()),
125+
}
126+
}
127+
}
128+
60129
pub async fn query_portal(
61130
client: &InnerClient,
62131
portal: &Portal,
@@ -164,7 +233,27 @@ where
164233
I: IntoIterator<Item = P>,
165234
I::IntoIter: ExactSizeIterator,
166235
{
167-
let param_types = statement.params();
236+
encode_bind_with_statement_name_and_param_types(
237+
statement.name(),
238+
statement.params(),
239+
params,
240+
portal,
241+
buf,
242+
)
243+
}
244+
245+
fn encode_bind_with_statement_name_and_param_types<P, I>(
246+
statement_name: &str,
247+
param_types: &[Type],
248+
params: I,
249+
portal: &str,
250+
buf: &mut BytesMut,
251+
) -> Result<(), Error>
252+
where
253+
P: BorrowToSql,
254+
I: IntoIterator<Item = P>,
255+
I::IntoIter: ExactSizeIterator,
256+
{
168257
let params = params.into_iter();
169258

170259
if param_types.len() != params.len() {
@@ -181,7 +270,7 @@ where
181270
let mut error_idx = 0;
182271
let r = frontend::bind(
183272
portal,
184-
statement.name(),
273+
statement_name,
185274
param_formats,
186275
params.zip(param_types).enumerate(),
187276
|(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {

tokio-postgres/src/statement.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ struct StatementInner {
1414

1515
impl Drop for StatementInner {
1616
fn drop(&mut self) {
17+
if self.name.is_empty() {
18+
// Unnamed statements don't need to be closed
19+
return;
20+
}
1721
if let Some(client) = self.client.upgrade() {
1822
let buf = client.with_buf(|buf| {
1923
frontend::close(b'S', &self.name, buf).unwrap();
@@ -46,6 +50,15 @@ impl Statement {
4650
}))
4751
}
4852

53+
pub(crate) fn unnamed(params: Vec<Type>, columns: Vec<Column>) -> Statement {
54+
Statement(Arc::new(StatementInner {
55+
client: Weak::new(),
56+
name: String::new(),
57+
params,
58+
columns,
59+
}))
60+
}
61+
4962
pub(crate) fn name(&self) -> &str {
5063
&self.0.name
5164
}

tokio-postgres/src/transaction.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,15 @@ impl<'a> Transaction<'a> {
227227
query::query_portal(self.client.inner(), portal, max_rows).await
228228
}
229229

230+
/// Like `Client::query_typed`.
231+
pub async fn query_typed(
232+
&self,
233+
statement: &str,
234+
params: &[(&(dyn ToSql + Sync), Type)],
235+
) -> Result<Vec<Row>, Error> {
236+
self.client.query_typed(statement, params).await
237+
}
238+
230239
/// Like `Client::copy_in`.
231240
pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, Error>
232241
where

tokio-postgres/tests/test/main.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,109 @@ async fn deferred_constraint() {
959959
.await
960960
.unwrap_err();
961961
}
962+
963+
#[tokio::test]
964+
async fn query_typed_no_transaction() {
965+
let client = connect("user=postgres").await;
966+
967+
client
968+
.batch_execute(
969+
"
970+
CREATE TEMPORARY TABLE foo (
971+
name TEXT,
972+
age INT
973+
);
974+
INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40);
975+
",
976+
)
977+
.await
978+
.unwrap();
979+
980+
let rows: Vec<tokio_postgres::Row> = client
981+
.query_typed(
982+
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
983+
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
984+
)
985+
.await
986+
.unwrap();
987+
988+
assert_eq!(rows.len(), 2);
989+
let first_row = &rows[0];
990+
assert_eq!(first_row.get::<_, &str>(0), "bob");
991+
assert_eq!(first_row.get::<_, i32>(1), 30);
992+
assert_eq!(first_row.get::<_, &str>(2), "literal");
993+
assert_eq!(first_row.get::<_, i32>(3), 5);
994+
995+
let second_row = &rows[1];
996+
assert_eq!(second_row.get::<_, &str>(0), "carol");
997+
assert_eq!(second_row.get::<_, i32>(1), 40);
998+
assert_eq!(second_row.get::<_, &str>(2), "literal");
999+
assert_eq!(second_row.get::<_, i32>(3), 5);
1000+
}
1001+
1002+
#[tokio::test]
1003+
async fn query_typed_with_transaction() {
1004+
let mut client = connect("user=postgres").await;
1005+
1006+
client
1007+
.batch_execute(
1008+
"
1009+
CREATE TEMPORARY TABLE foo (
1010+
name TEXT,
1011+
age INT
1012+
);
1013+
",
1014+
)
1015+
.await
1016+
.unwrap();
1017+
1018+
let transaction = client.transaction().await.unwrap();
1019+
1020+
let rows: Vec<tokio_postgres::Row> = transaction
1021+
.query_typed(
1022+
"INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age",
1023+
&[
1024+
(&"alice", Type::TEXT),
1025+
(&20i32, Type::INT4),
1026+
(&"bob", Type::TEXT),
1027+
(&30i32, Type::INT4),
1028+
(&"carol", Type::TEXT),
1029+
(&40i32, Type::INT4),
1030+
],
1031+
)
1032+
.await
1033+
.unwrap();
1034+
let inserted_values: Vec<(String, i32)> = rows
1035+
.iter()
1036+
.map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1)))
1037+
.collect();
1038+
assert_eq!(
1039+
inserted_values,
1040+
[
1041+
("alice".to_string(), 20),
1042+
("bob".to_string(), 30),
1043+
("carol".to_string(), 40)
1044+
]
1045+
);
1046+
1047+
let rows: Vec<tokio_postgres::Row> = transaction
1048+
.query_typed(
1049+
"SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age",
1050+
&[(&"alice", Type::TEXT), (&50i32, Type::INT4)],
1051+
)
1052+
.await
1053+
.unwrap();
1054+
1055+
assert_eq!(rows.len(), 2);
1056+
let first_row = &rows[0];
1057+
assert_eq!(first_row.get::<_, &str>(0), "bob");
1058+
assert_eq!(first_row.get::<_, i32>(1), 30);
1059+
assert_eq!(first_row.get::<_, &str>(2), "literal");
1060+
assert_eq!(first_row.get::<_, i32>(3), 5);
1061+
1062+
let second_row = &rows[1];
1063+
assert_eq!(second_row.get::<_, &str>(0), "carol");
1064+
assert_eq!(second_row.get::<_, i32>(1), 40);
1065+
assert_eq!(second_row.get::<_, &str>(2), "literal");
1066+
assert_eq!(second_row.get::<_, i32>(3), 5);
1067+
}

0 commit comments

Comments
 (0)