Skip to content

Commit

Permalink
add copy_both_simple method
Browse files Browse the repository at this point in the history
Signed-off-by: Petros Angelatos <[email protected]>
  • Loading branch information
petrosagg committed Aug 21, 2024
1 parent e7ecfce commit c125179
Show file tree
Hide file tree
Showing 7 changed files with 576 additions and 3 deletions.
34 changes: 34 additions & 0 deletions postgres-protocol/src/message/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const DATA_ROW_TAG: u8 = b'D';
pub const ERROR_RESPONSE_TAG: u8 = b'E';
pub const COPY_IN_RESPONSE_TAG: u8 = b'G';
pub const COPY_OUT_RESPONSE_TAG: u8 = b'H';
pub const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
pub const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
Expand Down Expand Up @@ -93,6 +94,7 @@ pub enum Message {
CopyDone,
CopyInResponse(CopyInResponseBody),
CopyOutResponse(CopyOutResponseBody),
CopyBothResponse(CopyBothResponseBody),
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
Expand Down Expand Up @@ -190,6 +192,16 @@ impl Message {
storage,
})
}
COPY_BOTH_RESPONSE_TAG => {
let format = buf.read_u8()?;
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::CopyBothResponse(CopyBothResponseBody {
format,
len,
storage,
})
}
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
Expand Down Expand Up @@ -524,6 +536,28 @@ impl CopyOutResponseBody {
}
}

#[derive(Debug, Clone)]
pub struct CopyBothResponseBody {
format: u8,
len: u16,
storage: Bytes,
}

impl CopyBothResponseBody {
#[inline]
pub fn format(&self) -> u8 {
self.format
}

#[inline]
pub fn column_formats(&self) -> ColumnFormats<'_> {
ColumnFormats {
remaining: self.len,
buf: &self.storage,
}
}
}

#[derive(Debug, Clone)]
pub struct DataRowBody {
storage: Bytes,
Expand Down
48 changes: 45 additions & 3 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::codec::BackendMessages;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::copy_both::{CopyBothDuplex, CopyBothReceiver};
use crate::copy_out::CopyOutStream;
#[cfg(feature = "runtime")]
use crate::keepalive::KeepaliveConfig;
Expand All @@ -13,8 +14,9 @@ use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error,
Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
copy_both, copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken,
CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction,
TransactionBuilder,
};
use bytes::{Buf, BytesMut};
use fallible_iterator::FallibleIterator;
Expand All @@ -41,6 +43,11 @@ pub struct Responses {
cur: BackendMessages,
}

pub struct CopyBothHandles {
pub(crate) stream_receiver: mpsc::Receiver<Result<Message, Error>>,
pub(crate) sink_sender: mpsc::Sender<FrontendMessage>,
}

impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
Expand Down Expand Up @@ -115,6 +122,32 @@ impl InnerClient {
})
}

pub fn start_copy_both(&self) -> Result<CopyBothHandles, Error> {
let (sender, receiver) = mpsc::channel(1);
let (stream_sender, stream_receiver) = mpsc::channel(0);
let (sink_sender, sink_receiver) = mpsc::channel(0);

let responses = Responses {
receiver,
cur: BackendMessages::empty(),
};
let messages = RequestMessages::CopyBoth(CopyBothReceiver::new(
responses,
sink_receiver,
stream_sender,
));

let request = Request { messages, sender };
self.sender
.unbounded_send(request)
.map_err(|_| Error::closed())?;

Ok(CopyBothHandles {
stream_receiver,
sink_sender,
})
}

pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
Expand Down Expand Up @@ -505,6 +538,15 @@ impl Client {
copy_out::copy_out(self.inner(), statement).await
}

/// Executes a CopyBoth query, returning a combined Stream+Sink type to read and write copy
/// data.
pub async fn copy_both_simple<T>(&self, query: &str) -> Result<CopyBothDuplex<T>, Error>
where
T: Buf + 'static + Send,
{
copy_both::copy_both_simple(self.inner(), query).await
}

/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
Expand Down
20 changes: 20 additions & 0 deletions tokio-postgres/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::copy_both::CopyBothReceiver;
use crate::copy_in::CopyInReceiver;
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
Expand All @@ -20,6 +21,7 @@ use tokio_util::codec::Framed;
pub enum RequestMessages {
Single(FrontendMessage),
CopyIn(CopyInReceiver),
CopyBoth(CopyBothReceiver),
}

pub struct Request {
Expand Down Expand Up @@ -258,6 +260,24 @@ where
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyIn(receiver));
}
RequestMessages::CopyBoth(mut receiver) => {
let message = match receiver.poll_next_unpin(cx) {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => {
trace!("poll_write: finished copy_both request");
continue;
}
Poll::Pending => {
trace!("poll_write: waiting on copy_both stream");
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
return Ok(true);
}
};
Pin::new(&mut self.stream)
.start_send(message)
.map_err(Error::io)?;
self.pending_request = Some(RequestMessages::CopyBoth(receiver));
}
}
}
}
Expand Down
Loading

0 comments on commit c125179

Please sign in to comment.