Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replication stream/sink splitting #2

Open
wants to merge 1 commit into
base: replication2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 158 additions & 1 deletion tokio-postgres/src/copy_both.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,18 @@ enum SinkState {
}

pin_project! {
/// A sink for `COPY ... FROM STDIN` query data.
/// A sink & stream for `CopyBoth` replication messages
///
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
/// not, the copy will be aborted.
///
/// The duplex can be split into the separate sink and stream with the [`split`] method. When
/// using this, they must be re-joined before finishing in order to properly complete the copy.
///
/// Both the implementation of [`Stream`] and [`Sink`] provide access to the bytes wrapped
/// inside of the `CopyData` wrapper.
///
/// [`split`]: Self::split
pub struct CopyBothDuplex<T> {
#[pin]
sender: mpsc::Sender<CopyBothMessage>,
Expand Down Expand Up @@ -146,6 +154,53 @@ where
pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
}

/// Splits the streams into distinct [`Sink`] and [`Stream`] components
///
/// Please note that there must be an eventual call to [`join`] the two components in order to
/// properly close the connection with [`finish`]; no corresponding method exists for the two
/// halves alone.
///
/// [`join`]: Self::join
/// [`finish`]: Self::finish
pub fn split(self) -> (Sender<T>, Receiver) {
let send = Sender {
sender: self.sender,
buf: self.buf,
state: self.state,
marker: PhantomData,
closed: false,
};

let recv = Receiver {
responses: self.responses,
};

(send, recv)
}

/// Joins the two halves of a `CopyBothDuplex` after a call to [`split`]
///
/// Note: We do not check that the sender and recevier originated from the same
/// [`CopyBothDuplex`]. If they did not, unexpected behavior *will* occur.
///
/// ## Panics
///
/// If the sender has already been closed, this function will panic.
///
/// [`split`]: Self::split
pub fn join(send: Sender<T>, recv: Receiver) -> Self {
assert!(!send.closed);

CopyBothDuplex {
sender: send.sender,
responses: recv.responses,
buf: send.buf,
state: send.state,
_p: PhantomPinned,
_p2: PhantomData,
}
}
}

impl<T> Stream for CopyBothDuplex<T> {
Expand All @@ -157,6 +212,7 @@ impl<T> Stream for CopyBothDuplex<T> {
match ready!(this.responses.poll_next(cx)?) {
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
Message::CopyDone => Poll::Ready(None),
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
Expand Down Expand Up @@ -220,6 +276,107 @@ where
}
}

pin_project! {
/// The receiving half of a [`CopyBothDuplex`]
///
/// Receiving the next message is done through the [`Stream`] implementation.
pub struct Receiver {
responses: Responses,
}
}

pin_project! {
/// The sending half of a [`CopyBothDuplex`]
///
/// Sending each message is done through the [`Sink`] implementation.
pub struct Sender<T> {
#[pin]
sender: mpsc::Sender<CopyBothMessage>,
buf: BytesMut,
state: SinkState,
marker: PhantomData<T>,
// True iff the sink has been closed. Causes further operations to panic.
closed: bool,
}
}

impl Stream for Receiver {
type Item = Result<Bytes, Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();

match ready!(this.responses.poll_next(cx)?) {
Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))),
Message::CopyDone => Poll::Ready(None),
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}

impl<T> Sink<T> for Sender<T>
where
T: Buf + 'static + Send,
{
type Error = Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.project()
.sender
.poll_ready(cx)
.map_err(|_| Error::closed())
}

fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
assert!(!self.closed);

let this = self.project();

let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
if this.buf.is_empty() {
Box::new(item)
} else {
Box::new(this.buf.split().freeze().chain(item))
}
} else {
this.buf.put(item);
if this.buf.len() > 4096 {
Box::new(this.buf.split().freeze())
} else {
return Ok(());
}
};

let data = CopyData::new(data).map_err(Error::encode)?;
this.sender
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
.map_err(|_| Error::closed())
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let mut this = self.project();

if !this.buf.is_empty() {
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
this.sender
.as_mut()
.start_send(CopyBothMessage::Message(FrontendMessage::CopyData(data)))
.map_err(|_| Error::closed())?;
}

this.sender.poll_flush(cx).map_err(|_| Error::closed())
}

// Closing the sink "normally" will just abort the copy.
fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.closed = true;
Poll::Ready(Ok(()))
}
}

pub async fn copy_both_simple<T>(
client: &InnerClient,
query: &str,
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
pub use crate::copy_both::CopyBothDuplex;
pub use crate::copy_both::{CopyBothDuplex, Receiver as CopyBothStream, Sender as CopyBothSink};
pub use crate::copy_in::CopyInSink;
pub use crate::copy_out::CopyOutStream;
use crate::error::DbError;
Expand Down