Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support unnamed statements #1067

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl Header {
}

/// An enum representing Postgres backend messages.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
Expand Down Expand Up @@ -333,6 +334,7 @@ impl Read for Buffer {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
Expand All @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationGssContinueBody(Bytes);

impl AuthenticationGssContinueBody {
Expand All @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslBody(Bytes);

impl AuthenticationSaslBody {
Expand All @@ -362,6 +366,7 @@ impl AuthenticationSaslBody {
}
}

#[derive(Debug, PartialEq)]
pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> FallibleIterator for SaslMechanisms<'a> {
Expand All @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslContinueBody(Bytes);

impl AuthenticationSaslContinueBody {
Expand All @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslFinalBody(Bytes);

impl AuthenticationSaslFinalBody {
Expand All @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody {
}
}

#[derive(Debug, PartialEq)]
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
Expand All @@ -422,6 +430,7 @@ impl BackendKeyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CommandCompleteBody {
tag: Bytes,
}
Expand All @@ -433,6 +442,7 @@ impl CommandCompleteBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyDataBody {
storage: Bytes,
}
Expand All @@ -449,6 +459,7 @@ impl CopyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyInResponseBody {
format: u8,
len: u16,
Expand All @@ -470,6 +481,7 @@ impl CopyInResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ColumnFormats<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyOutResponseBody {
format: u8,
len: u16,
Expand All @@ -524,7 +537,7 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
Expand Down Expand Up @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ErrorResponseBody {
storage: Bytes,
}
Expand Down Expand Up @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct NoticeResponseBody {
storage: Bytes,
}
Expand All @@ -668,6 +683,7 @@ impl NoticeResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
Expand All @@ -691,6 +707,7 @@ impl NotificationResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
Expand All @@ -706,6 +723,7 @@ impl ParameterDescriptionBody {
}
}

#[derive(Debug, PartialEq)]
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
Expand All @@ -756,6 +775,7 @@ impl ParameterStatusBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ReadyForQueryBody {
status: u8,
}
Expand All @@ -767,6 +787,7 @@ impl ReadyForQueryBody {
}
}

#[derive(Debug, PartialEq)]
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" }
tokio = { version = "1.27", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec"] }
rand = "0.8.5"
whoami = "1.4.1"
whoami = "1.4"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
socket2 = { version = "0.5", features = ["all"] }
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where

match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(Portal::new(client, name, statement))
Expand Down
6 changes: 5 additions & 1 deletion tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@ impl Client {
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
prepare::prepare(&self.inner, query, parameter_types, false).await
}

pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, &[], true).await
}

/// Executes a statement, returning a vector of the resulting rows.
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ where
}
}
Some(_) => {}
None => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions tokio-postgres/src/connect_raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ where
))
}
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Some(m) => return Err(Error::unexpected_message(m)),
None => return Err(Error::closed()),
}

match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => Ok(()),
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
Some(_) => Err(Error::unexpected_message()),
Some(m) => Err(Error::unexpected_message(m)),
None => Err(Error::closed()),
}
}
Expand Down Expand Up @@ -291,7 +291,7 @@ where
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Some(m) => return Err(Error::unexpected_message(m)),
None => return Err(Error::closed()),
};

Expand All @@ -309,7 +309,7 @@ where
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Some(m) => return Err(Error::unexpected_message(m)),
None => return Err(Error::closed()),
};

Expand Down Expand Up @@ -348,7 +348,7 @@ where
}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
Some(m) => return Err(Error::unexpected_message(m)),
None => return Err(Error::closed()),
}
}
Expand Down
3 changes: 2 additions & 1 deletion tokio-postgres/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ where
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
Some(m) => return Err(Error::unexpected_message(m)),
None => return Err(Error::closed()),
},
};

Expand Down
12 changes: 9 additions & 3 deletions tokio-postgres/src/copy_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ where
let rows = extract_row_affected(&body)?;
return Poll::Ready(Ok(rows));
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
m => return Poll::Ready(Err(Error::unexpected_message(m))),
}
}
}
Expand Down Expand Up @@ -206,13 +206,19 @@ where
.map_err(|_| Error::closed())?;

match responses.next().await? {
Message::ParseComplete => {
match responses.next().await? {
Message::BindComplete => {}
m => return Err(Error::unexpected_message(m)),
}
}
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

match responses.next().await? {
Message::CopyInResponse(_) => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(CopyInSink {
Expand Down
10 changes: 7 additions & 3 deletions tokio-postgres/src/copy_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,17 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

match responses.next().await? {
Message::ParseComplete => match responses.next().await? {
Message::BindComplete => {}
m => return Err(Error::unexpected_message(m)),
},
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

match responses.next().await? {
Message::CopyOutResponse(_) => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(responses)
Expand All @@ -56,7 +60,7 @@ impl Stream for CopyOutStream {
match ready!(this.responses.poll_next(cx)?) {
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
Message::CopyDone => Poll::Ready(None),
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
m => Poll::Ready(Some(Err(Error::unexpected_message(m)))),
}
}
}
12 changes: 7 additions & 5 deletions tokio-postgres/src/error/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Errors.

use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody, Message};
use std::error::{self, Error as _Error};
use std::fmt;
use std::io;
Expand Down Expand Up @@ -339,7 +339,7 @@ pub enum ErrorPosition {
#[derive(Debug, PartialEq)]
enum Kind {
Io,
UnexpectedMessage,
UnexpectedMessage(Message),
Tls,
ToSql(usize),
FromSql(usize),
Expand Down Expand Up @@ -379,7 +379,9 @@ impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::UnexpectedMessage(msg) => {
write!(fmt, "unexpected message from server: {:?}", msg)?
}
Kind::Tls => fmt.write_str("error performing TLS handshake")?,
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Expand Down Expand Up @@ -445,8 +447,8 @@ impl Error {
Error::new(Kind::Closed, None)
}

pub(crate) fn unexpected_message() -> Error {
Error::new(Kind::UnexpectedMessage, None)
pub(crate) fn unexpected_message(message: Message) -> Error {
Error::new(Kind::UnexpectedMessage(message), None)
}

#[allow(clippy::needless_pass_by_value)]
Expand Down
Loading