Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text protocol based query method (#14) #1079

Closed
Closed
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: sfackler/actions/rustup@master
- run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT
with:
version: 1.67.0
- run: echo "::set-output name=version::$(rustc --version)"
id: rust-version
- run: rustup target add wasm32-unknown-unknown
- uses: actions/cache@v3
Expand Down
23 changes: 22 additions & 1 deletion postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl Header {
}

/// An enum representing Postgres backend messages.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
Expand Down Expand Up @@ -333,6 +334,7 @@ impl Read for Buffer {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
Expand All @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationGssContinueBody(Bytes);

impl AuthenticationGssContinueBody {
Expand All @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslBody(Bytes);

impl AuthenticationSaslBody {
Expand All @@ -362,6 +366,7 @@ impl AuthenticationSaslBody {
}
}

#[derive(Debug, PartialEq)]
pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> FallibleIterator for SaslMechanisms<'a> {
Expand All @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslContinueBody(Bytes);

impl AuthenticationSaslContinueBody {
Expand All @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody {
}
}

#[derive(Debug, PartialEq)]
pub struct AuthenticationSaslFinalBody(Bytes);

impl AuthenticationSaslFinalBody {
Expand All @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody {
}
}

#[derive(Debug, PartialEq)]
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
Expand All @@ -422,6 +430,7 @@ impl BackendKeyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CommandCompleteBody {
tag: Bytes,
}
Expand All @@ -433,6 +442,7 @@ impl CommandCompleteBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyDataBody {
storage: Bytes,
}
Expand All @@ -449,6 +459,7 @@ impl CopyDataBody {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyInResponseBody {
format: u8,
len: u16,
Expand All @@ -470,6 +481,7 @@ impl CopyInResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ColumnFormats<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct CopyOutResponseBody {
format: u8,
len: u16,
Expand All @@ -524,7 +537,7 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
Expand Down Expand Up @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ErrorResponseBody {
storage: Bytes,
}
Expand Down Expand Up @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct NoticeResponseBody {
storage: Bytes,
}
Expand All @@ -668,6 +683,7 @@ impl NoticeResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
Expand All @@ -691,6 +707,7 @@ impl NotificationResponseBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
Expand All @@ -706,6 +723,7 @@ impl ParameterDescriptionBody {
}
}

#[derive(Debug, PartialEq)]
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
Expand Down Expand Up @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> {
}
}

#[derive(Debug, PartialEq)]
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
Expand All @@ -756,6 +775,7 @@ impl ParameterStatusBody {
}
}

#[derive(Debug, PartialEq)]
pub struct ReadyForQueryBody {
status: u8,
}
Expand All @@ -767,6 +787,7 @@ impl ReadyForQueryBody {
}
}

#[derive(Debug, PartialEq)]
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
Expand Down
18 changes: 17 additions & 1 deletion postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,22 @@ impl WrongType {
}
}

/// An error indicating that a as_text conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}

impl Error for WrongFormat {}

impl fmt::Display for WrongFormat {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot read column as text while it is in binary format"
)
}
}

/// A trait for types that can be created from a Postgres value.
///
/// # Types
Expand Down Expand Up @@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug {
/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Text format (UTF-8)
Text,
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" }
tokio = { version = "1.27", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec"] }
rand = "0.8.5"
whoami = "1.4.1"
whoami = "1.4"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
socket2 = { version = "0.5", features = ["all"] }
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where

match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
m => return Err(Error::unexpected_message(m)),
}

Ok(Portal::new(client, name, statement))
Expand Down
77 changes: 24 additions & 53 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
use crate::keepalive::KeepaliveConfig;
use crate::query::RowStream;
use crate::query::{RowStream, };
use crate::simple_query::SimpleQueryStream;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
Expand Down Expand Up @@ -61,29 +61,9 @@ impl Responses {
}
}

/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare](prepare) module).
#[derive(Default)]
struct CachedTypeInfo {
/// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
/// fallback).
typeinfo: Option<Statement>,
/// A statement for getting information for a composite type from its OID.
/// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
typeinfo_composite: Option<Statement>,
/// A statement for getting information for a composite type from its OID.
/// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or
/// its fallback).
typeinfo_enum: Option<Statement>,

/// Cache of types already looked up.
types: HashMap<Oid, Type>,
}

pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
cached_typeinfo: Mutex<CachedTypeInfo>,
cached_typeinfo: Mutex<HashMap<Oid, Type>>,

/// A buffer to use when writing out postgres commands.
buffer: Mutex<BytesMut>,
Expand All @@ -103,40 +83,12 @@ impl InnerClient {
})
}

pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}

pub fn set_typeinfo(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
}

pub fn typeinfo_composite(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_composite.clone()
}

pub fn set_typeinfo_composite(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
}

pub fn typeinfo_enum(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_enum.clone()
}

pub fn set_typeinfo_enum(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
}

pub fn type_(&self, oid: Oid) -> Option<Type> {
self.cached_typeinfo.lock().types.get(&oid).cloned()
}

pub fn set_type(&self, oid: Oid, type_: &Type) {
self.cached_typeinfo.lock().types.insert(oid, type_.clone());
self.cached_typeinfo.lock().get(&oid).cloned()
}

pub fn clear_type_cache(&self) {
self.cached_typeinfo.lock().types.clear();
self.cached_typeinfo.lock().clear();
}

/// Call the given function with a buffer to be used when writing out
Expand Down Expand Up @@ -231,7 +183,11 @@ impl Client {
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
prepare::prepare(&self.inner, query, parameter_types, false).await
}

pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, &[], true).await
}

/// Executes a statement, returning a vector of the resulting rows.
Expand Down Expand Up @@ -368,6 +324,21 @@ impl Client {
query::query(&self.inner, statement, params).await
}

/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<'a, S, I>(
&self,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, query, params).await
}

/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
Expand Down
13 changes: 12 additions & 1 deletion tokio-postgres/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages {
}
}

pub struct PostgresCodec;
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}

impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
Expand Down Expand Up @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec {
break;
}

if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}

match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
Expand Down
Loading
Loading