Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::generic_result::GenericResult;
pub use crate::portal::Portal;
pub use crate::query::{RowStream, ResultStream};
pub use crate::query::{ResultStream, RowStream};
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
#[cfg(feature = "runtime")]
Expand Down
6 changes: 5 additions & 1 deletion tokio-postgres/src/simple_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ impl Stream for SimpleQueryStream {
.parse()
.unwrap_or(0);
let fields = if *this.include_fields_in_complete {
this.fields.clone()
if body.tag().expect("Failed to get tag").starts_with("SELECT") {
this.fields.clone()
} else {
None
}
} else {
// Reset bool for next grouping
*this.include_fields_in_complete = true;
Expand Down
81 changes: 81 additions & 0 deletions tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,88 @@ async fn custom_range() {
assert_eq!("floatrange", ty.name());
assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind());
}
/// This test check to make sure that empty responses for select queries include the header but not
/// for other query types.
#[tokio::test]
async fn simple_query_select_transaction() {
let client = connect("user=postgres").await;

let _ = client.simple_query("DROP TABLE sbtest1").await.unwrap();
let _ = client.simple_query("DROP TABLE sbtest2").await.unwrap();
let _ = client
.simple_query("CREATE TABLE sbtest1 (id INTEGER, k INTEGER);")
.await
.unwrap();
let _ = client
.simple_query("CREATE TABLE sbtest2 (id INTEGER, k INTEGER);")
.await
.unwrap();

let messages = client
.simple_query(
"INSERT INTO sbtest1 VALUES (1, 2);
INSERT INTO sbtest2 VALUES (1, 2);
SELECT * FROM sbtest1 ORDER BY id;
SELECT k FROM sbtest1 WHERE id = 999;
BEGIN;
UPDATE sbtest1 SET k=id;
UPDATE sbtest2 SET k=id;
END;",
)
.await
.unwrap();

match messages[0] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
_ => panic!("unexpected message or too many rows"),
}
match messages[1] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
_ => panic!("unexpected message or too many rows"),
}
match &messages[2] {
SimpleQueryMessage::Row(row) => {
assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id"));
assert_eq!(row.columns().get(1).map(|c| c.name()), Some("k"));
assert_eq!(row.get(0), Some("1"));
assert_eq!(row.get(1), Some("2"));
}
_ => panic!("unexpected message"),
}
match &messages[3] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty "),
}
match &messages[4] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields, .. }) => {
if let Some(field_vec) = &fields {
assert_eq!((&**field_vec).len(), 1);
assert_eq!("k", (&**field_vec)[0].name());
} else {
panic!("No data found");
}
}
_ => panic!("unexpected message"),
}
match &messages[5] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[6] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[7] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}
match &messages[8] {
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: None, .. }) => {}
_ => panic!("unexpected message or fields are not empty"),
}

assert_eq!(messages.len(), 9);
}
#[tokio::test]
async fn simple_query() {
let client = connect("user=postgres").await;
Expand Down