Skip to content

Commit 0ba4058

Browse files
committed
Transactions aren't executed with "psql -c"
The transaction did not appear to execute when running a set of SQL commands using psql commands with the '-c' option, such as "SELECT * FROM t; BEGIN; UPDATE t SET k = id; END;". This issue was caused by the tokio-postgres crate adding column information to all EmptyQueryResponses for every SQL query. The resolution for this issue was to only add the column information to 'SELECT' queries.
1 parent 13206f6 commit 0ba4058

File tree

3 files changed

+250
-2
lines changed

3 files changed

+250
-2
lines changed

tokio-postgres/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ pub use crate::error::Error;
132132
pub use crate::generic_client::GenericClient;
133133
pub use crate::generic_result::GenericResult;
134134
pub use crate::portal::Portal;
135-
pub use crate::query::{RowStream, ResultStream};
135+
pub use crate::query::{ResultStream, RowStream};
136136
pub use crate::row::{Row, SimpleQueryRow};
137137
pub use crate::simple_query::SimpleQueryStream;
138138
#[cfg(feature = "runtime")]

tokio-postgres/src/simple_query.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,12 @@ impl Stream for SimpleQueryStream {
9898
.parse()
9999
.unwrap_or(0);
100100
let fields = if *this.include_fields_in_complete {
101-
this.fields.clone()
101+
let _tag = body.tag().expect("Failed to get tag");
102+
if _tag.starts_with("SELECT") {
103+
this.fields.clone()
104+
} else {
105+
None
106+
}
102107
} else {
103108
// Reset bool for next grouping
104109
*this.include_fields_in_complete = true;

tokio-postgres/tests/test/main.rs

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,250 @@ async fn custom_range() {
347347
assert_eq!("floatrange", ty.name());
348348
assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind());
349349
}
350+
/// This test check to make sure that empty responses for select queries include the header but not
351+
/// for other query types.
352+
#[tokio::test]
353+
async fn simple_query_select_transaction() {
354+
let client = connect("user=postgres").await;
355+
356+
// The following statements configure the tables for the test.
357+
let _ = client.simple_query("DROP TABLE sbtest1").await;
358+
let _ = client.simple_query("DROP TABLE sbtest2").await;
359+
let _ = client
360+
.simple_query("CREATE TABLE sbtest1 (id INTEGER, k INTEGER);")
361+
.await;
362+
let _ = client
363+
.simple_query("CREATE TABLE sbtest2 (id INTEGER, k INTEGER);")
364+
.await;
365+
/// This test check to make sure that empty responses for select queries include the header but not
366+
/// for other query types.
367+
#[tokio::test]
368+
async fn simple_query_select_transaction() {
369+
let client = connect("user=postgres").await;
370+
371+
match client.simple_query("DROP TABLE sbtest1").await {
372+
Ok(_) => {
373+
println!("Query executed successfully");
374+
// Your code here when the query succeeds
375+
}
376+
Err(e) => {
377+
eprintln!("Error executing query: {}", e);
378+
// Your code here when the query fails
379+
}
380+
}
381+
382+
match client.simple_query("DROP TABLE sbtest2").await {
383+
Ok(_) => {
384+
println!("Query executed successfully");
385+
// Your code here when the query succeeds
386+
}
387+
Err(e) => {
388+
eprintln!("Error executing query: {}", e);
389+
// Your code here when the query fails
390+
}
391+
}
392+
393+
match client
394+
.simple_query("CREATE TABLE sbtest1 (id INTEGER, k INTEGER);")
395+
.await
396+
{
397+
Ok(_) => {
398+
println!("Query executed successfully");
399+
// Your code here when the query succeeds
400+
}
401+
Err(e) => {
402+
eprintln!("Error executing query: {}", e);
403+
// Your code here when the query fails
404+
}
405+
}
406+
407+
let messages = client
408+
.simple_query(
409+
"INSERT INTO sbtest1 VALUES (1, 2);
410+
INSERT INTO sbtest2 VALUES (1, 2);
411+
SELECT * FROM sbtest1 ORDER BY id;
412+
SELECT k FROM sbtest1 WHERE id = 999;
413+
BEGIN;
414+
UPDATE sbtest1 SET k=id;
415+
UPDATE sbtest2 SET k=id;
416+
END;",
417+
)
418+
.await
419+
.unwrap();
420+
421+
match messages[0] {
422+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
423+
_ => panic!("unexpected message"),
424+
}
425+
match messages[1] {
426+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
427+
_ => panic!("unexpected message"),
428+
}
429+
match &messages[2] {
430+
SimpleQueryMessage::Row(row) => {
431+
assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id"));
432+
assert_eq!(row.columns().get(1).map(|c| c.name()), Some("k"));
433+
assert_eq!(row.get(0), Some("1"));
434+
assert_eq!(row.get(1), Some("2"));
435+
}
436+
_ => panic!("unexpected message"),
437+
}
438+
match &messages[3] {
439+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
440+
if f.is_some() {
441+
panic!("Unexpected data found");
442+
}
443+
}
444+
_ => panic!("unexpected message"),
445+
}
446+
match &messages[4] {
447+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
448+
if let Some(field_vec) = &f {
449+
assert_eq!((&**field_vec).len(), 1);
450+
assert_eq!("k", (&**field_vec)[0].name());
451+
} else {
452+
panic!("No data found");
453+
}
454+
}
455+
_ => panic!("unexpected message"),
456+
}
457+
match &messages[5] {
458+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
459+
if f.is_some() {
460+
panic!("Unexpected data found");
461+
}
462+
}
463+
_ => panic!("unexpected message"),
464+
}
465+
match &messages[6] {
466+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
467+
if f.is_some() {
468+
panic!("Unexpected data found");
469+
}
470+
}
471+
_ => panic!("unexpected message"),
472+
}
473+
match &messages[7] {
474+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
475+
if f.is_some() {
476+
panic!("Unexpected data found");
477+
}
478+
}
479+
_ => panic!("unexpected message"),
480+
}
481+
match &messages[8] {
482+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
483+
if f.is_some() {
484+
panic!("Unexpected data found");
485+
}
486+
}
487+
_ => panic!("unexpected message"),
488+
}
489+
490+
assert_eq!(messages.len(), 9);
491+
}
492+
match client
493+
.simple_query("CREATE TABLE sbtest2 (id INTEGER, k INTEGER);")
494+
.await
495+
{
496+
Ok(_) => {
497+
println!("Query executed successfully");
498+
// Your code here when the query succeeds
499+
}
500+
Err(e) => {
501+
eprintln!("Error executing query: {}", e);
502+
// Your code here when the query fails
503+
}
504+
}
350505

506+
let messages = client
507+
.simple_query(
508+
"INSERT INTO sbtest1 VALUES (1, 2);
509+
INSERT INTO sbtest2 VALUES (1, 2);
510+
SELECT * FROM sbtest1 ORDER BY id;
511+
SELECT k FROM sbtest1 WHERE id = 999;
512+
BEGIN;
513+
UPDATE sbtest1 SET k=id;
514+
UPDATE sbtest2 SET k=id;
515+
END;",
516+
)
517+
.await
518+
.unwrap();
519+
520+
println!("{:?}", messages);
521+
let _m = format!("{:?}", messages);
522+
523+
match messages[0] {
524+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
525+
_ => panic!("unexpected message"),
526+
}
527+
match messages[1] {
528+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { rows: 1, .. }) => {}
529+
_ => panic!("unexpected message"),
530+
}
531+
match &messages[2] {
532+
SimpleQueryMessage::Row(row) => {
533+
assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id"));
534+
assert_eq!(row.columns().get(1).map(|c| c.name()), Some("k"));
535+
assert_eq!(row.get(0), Some("1"));
536+
assert_eq!(row.get(1), Some("2"));
537+
}
538+
_ => panic!("unexpected message"),
539+
}
540+
match &messages[3] {
541+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
542+
if f.is_some() {
543+
panic!("Unexpected data found");
544+
}
545+
}
546+
_ => panic!("unexpected message"),
547+
}
548+
match &messages[4] {
549+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
550+
if let Some(field_vec) = &f {
551+
assert_eq!((&**field_vec).len(), 1);
552+
assert_eq!("k", (&**field_vec)[0].name());
553+
} else {
554+
panic!("No data found");
555+
}
556+
}
557+
_ => panic!("unexpected message"),
558+
}
559+
match &messages[5] {
560+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
561+
if f.is_some() {
562+
panic!("Unexpected data found");
563+
}
564+
}
565+
_ => panic!("unexpected message"),
566+
}
567+
match &messages[6] {
568+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
569+
if f.is_some() {
570+
panic!("Unexpected data found");
571+
}
572+
}
573+
_ => panic!("unexpected message"),
574+
}
575+
match &messages[7] {
576+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
577+
if f.is_some() {
578+
panic!("Unexpected data found");
579+
}
580+
}
581+
_ => panic!("unexpected message"),
582+
}
583+
match &messages[8] {
584+
SimpleQueryMessage::CommandComplete(CommandCompleteContents { fields: f, .. }) => {
585+
if f.is_some() {
586+
panic!("Unexpected data found");
587+
}
588+
}
589+
_ => panic!("unexpected message"),
590+
}
591+
592+
assert_eq!(messages.len(), 9);
593+
}
351594
#[tokio::test]
352595
async fn simple_query() {
353596
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)