diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 358eba143..29192ccb5 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -132,7 +132,7 @@ pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::generic_result::GenericResult; pub use crate::portal::Portal; -pub use crate::query::{RowStream, ResultStream}; +pub use crate::query::{ResultStream, RowStream}; pub use crate::row::{Row, SimpleQueryRow}; pub use crate::simple_query::SimpleQueryStream; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index 4ae6787ab..f2e7858a2 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -98,7 +98,11 @@ impl Stream for SimpleQueryStream { .parse() .unwrap_or(0); let fields = if *this.include_fields_in_complete { - this.fields.clone() + if body.tag().expect("Failed to get tag").starts_with("SELECT") { + this.fields.clone() + } else { + None + } } else { // Reset bool for next grouping *this.include_fields_in_complete = true; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index dfd95e153..ff23206bf 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -347,7 +347,88 @@ async fn custom_range() { assert_eq!("floatrange", ty.name()); assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind()); } +/// This test check to make sure that empty responses for select queries include the header but not +/// for other query types. +#[tokio::test] +async fn simple_query_select_transaction() { + let client = connect("user=postgres").await; + + let _ = client.simple_query("DROP TABLE sbtest1").await.unwrap(); + let _ = client.simple_query("DROP TABLE sbtest2").await.unwrap(); + let _ = client + .simple_query("CREATE TABLE sbtest1 (id INTEGER, k INTEGER);") + .await + .unwrap(); + let _ = client + .simple_query("CREATE TABLE sbtest2 (id INTEGER, k INTEGER);") + .await + .unwrap(); + + let messages = client + .simple_query( + "INSERT INTO sbtest1 VALUES (1, 2); + INSERT INTO sbtest2 VALUES (1, 2); + SELECT * FROM sbtest1 ORDER BY id; + SELECT k FROM sbtest1 WHERE id = 999; + BEGIN; + UPDATE sbtest1 SET k=id; + UPDATE sbtest2 SET k=id; + END;", + ) + .await + .unwrap(); + + match messages[0] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {} + _ => panic!("unexpected message or too many rows"), + } + match messages[1] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {} + _ => panic!("unexpected message or too many rows"), + } + match &messages[2] { + SimpleQueryMessage::Row(row) => { + assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().get(1).map(|c| c.name()), Some("k")); + assert_eq!(row.get(0), Some("1")); + assert_eq!(row.get(1), Some("2")); + } + _ => panic!("unexpected message"), + } + match &messages[3] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {} + _ => panic!("unexpected message or fields are not empty "), + } + match &messages[4] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields, .. }) => { + if let Some(field_vec) = &fields { + assert_eq!((&**field_vec).len(), 1); + assert_eq!("k", (&**field_vec)[0].name()); + } else { + panic!("No data found"); + } + } + _ => panic!("unexpected message"), + } + match &messages[5] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {} + _ => panic!("unexpected message or fields are not empty"), + } + match &messages[6] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {} + _ => panic!("unexpected message or fields are not empty"), + } + match &messages[7] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {} + _ => panic!("unexpected message or fields are not empty"), + } + match &messages[8] { + SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {} + _ => panic!("unexpected message or fields are not empty"), + } + assert_eq!(messages.len(), 9); +} #[tokio::test] async fn simple_query() { let client = connect("user=postgres").await;