Skip to content

Commit

Permalink
Change: remove AsyncSeek trait from snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
zach-schoenberger committed Nov 17, 2022
1 parent 229bb50 commit 1b1f046
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 40 deletions.
22 changes: 9 additions & 13 deletions openraft/src/core/streaming_state.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::io::SeekFrom;
use std::marker::PhantomData;

use tokio::io::AsyncSeek;
use tokio::io::AsyncSeekExt;
use anyerror::AnyError;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;

Expand All @@ -12,6 +10,7 @@ use crate::ErrorVerb;
use crate::RaftTypeConfig;
use crate::SnapshotId;
use crate::StorageError;
use crate::StorageIOError;

/// The Raft node is streaming in a snapshot from the leader.
pub(crate) struct StreamingState<C: RaftTypeConfig, SD> {
Expand All @@ -26,7 +25,7 @@ pub(crate) struct StreamingState<C: RaftTypeConfig, SD> {
}

impl<C: RaftTypeConfig, SD> StreamingState<C, SD>
where SD: AsyncSeek + AsyncWrite + Unpin
where SD: AsyncWrite + Unpin
{
pub(crate) fn new(snapshot_id: SnapshotId, snapshot_data: Box<SD>) -> Self {
Self {
Expand All @@ -41,16 +40,13 @@ where SD: AsyncSeek + AsyncWrite + Unpin
pub(crate) async fn receive(&mut self, req: InstallSnapshotRequest<C>) -> Result<bool, StorageError<C::NodeId>> {
// TODO: check id?

// Always seek to the target offset if not an exact match.
if req.offset != self.offset {
if let Err(err) = self.snapshot_data.as_mut().seek(SeekFrom::Start(req.offset)).await {
return Err(StorageError::from_io_error(
ErrorSubject::Snapshot(req.meta.signature()),
ErrorVerb::Seek,
err,
));
}
self.offset = req.offset;
let sto_io_err = StorageIOError::new(
ErrorSubject::Snapshot(req.meta.signature()),
ErrorVerb::Write,
AnyError::error(format!("offsets do not match {}:{}", self.offset, req.offset)),
);
return Err(StorageError::IO { source: sto_io_err });
}

// Write the next segment & update offset.
Expand Down
1 change: 1 addition & 0 deletions openraft/src/defensive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub trait DefensiveCheckBase<C: RaftTypeConfig> {
if !self.is_defensive() {
return Ok(());
}

let start = match range.start_bound() {
Bound::Included(i) => Some(*i),
Bound::Excluded(i) => Some(*i + 1),
Expand Down
56 changes: 34 additions & 22 deletions openraft/src/replication/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
//! Replication stream.
use std::io::SeekFrom;
use std::sync::Arc;

use futures::future::FutureExt;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncSeekExt;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -714,33 +712,29 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Replication
) -> Result<(), ReplicationError<C::NodeId, C::Node>> {
let err_x = || (ErrorSubject::Snapshot(snapshot.meta.signature()), ErrorVerb::Read);

let end = snapshot.snapshot.seek(SeekFrom::End(0)).await.sto_res(err_x)?;

let mut offset = 0;

let mut buf = Vec::with_capacity(self.config.snapshot_max_chunk_size as usize);

loop {
// Build the RPC.
snapshot.snapshot.seek(SeekFrom::Start(offset)).await.sto_res(err_x)?;
// Build the first RPC.
let mut n_read = snapshot.snapshot.read_buf(&mut buf).await.sto_res(err_x)?;

let n_read = snapshot.snapshot.read_buf(&mut buf).await.sto_res(err_x)?;
let mut done = n_read == 0;

let done = (offset + n_read as u64) == end; // If bytes read == 0, then we're done.
let req = InstallSnapshotRequest {
vote: self.vote,
meta: snapshot.meta.clone(),
offset,
data: Vec::from(&buf[..n_read]),
done,
};
buf.clear();
let mut req = InstallSnapshotRequest {
vote: self.vote,
meta: snapshot.meta.clone(),
offset,
data: Vec::from(&buf[..n_read]),
done,
};
buf.clear();

loop {
// Send the RPC over to the target.
tracing::debug!(
snapshot_size = req.data.len(),
req.offset,
end,
req.done,
"sending snapshot chunk"
);
Expand All @@ -751,7 +745,9 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Replication
self.config.send_snapshot_timeout()
};

let res = timeout(snap_timeout, self.network.send_install_snapshot(req)).await;
// TODO should not clone the request. If possible we should return the failed request or better yet pass the
// req as a reference
let res = timeout(snap_timeout, self.network.send_install_snapshot(req.clone())).await;

let res = match res {
Ok(outer_res) => match outer_res {
Expand All @@ -761,6 +757,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Replication

// Sleep a short time otherwise in test environment it is a dead-loop that never yields.
// Because network implementation does not yield.
#[cfg(test)]
sleep(Duration::from_millis(10)).await;
continue;
}
Expand All @@ -770,6 +767,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Replication

// Sleep a short time otherwise in test environment it is a dead-loop that never yields.
// Because network implementation does not yield.
#[cfg(test)]
sleep(Duration::from_millis(10)).await;
continue;
}
Expand All @@ -796,11 +794,25 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> Replication
return Ok(());
}

// Everything is good, so update offset for sending the next chunk.
offset += n_read as u64;

// Check raft channel to ensure we are staying up-to-date, then loop.
self.try_drain_raft_rx().await?;

// Build the next RPC.
// update offset for sending the next chunk.
offset += n_read as u64;

n_read = snapshot.snapshot.read_buf(&mut buf).await.sto_res(err_x)?;

done = n_read == 0;

req = InstallSnapshotRequest {
vote: self.vote,
meta: snapshot.meta.clone(),
offset,
data: Vec::from(&buf[..n_read]),
done,
};
buf.clear();
}
}
}
7 changes: 3 additions & 4 deletions openraft/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use async_trait::async_trait;
pub use helper::StorageHelper;
pub use snapshot_signature::SnapshotSignature;
use tokio::io::AsyncRead;
use tokio::io::AsyncSeek;
use tokio::io::AsyncWrite;

use crate::defensive::check_range_matches_entries;
Expand Down Expand Up @@ -77,7 +76,7 @@ pub struct Snapshot<NID, N, S>
where
NID: NodeId,
N: Node,
S: AsyncRead + AsyncSeek + Send + Unpin + 'static,
S: AsyncRead + Send + Unpin + 'static,
{
/// metadata of a snapshot
pub meta: SnapshotMeta<NID, N>,
Expand Down Expand Up @@ -164,7 +163,7 @@ where C: RaftTypeConfig
pub trait RaftSnapshotBuilder<C, SD>: Send + Sync + 'static
where
C: RaftTypeConfig,
SD: AsyncRead + AsyncWrite + AsyncSeek + Send + Sync + Unpin + 'static,
SD: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
/// Build snapshot
///
Expand Down Expand Up @@ -201,7 +200,7 @@ where C: RaftTypeConfig
///
/// See the [storage chapter of the guide](https://datafuselabs.github.io/openraft/getting-started.html#implement-raftstorage)
/// for details on where and how this is used.
type SnapshotData: AsyncRead + AsyncWrite + AsyncSeek + Send + Sync + Unpin + 'static;
type SnapshotData: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static;

/// Log reader type.
type LogReader: RaftLogReader<C>;
Expand Down
7 changes: 6 additions & 1 deletion openraft/tests/snapshot/t20_api_install_snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,13 @@ async fn snapshot_arguments() -> Result<()> {
let mut req = req0.clone();
req.offset = 8;
req.meta.snapshot_id = "ss2".into();
n.0.install_snapshot(req).await?;
let res = n.0.install_snapshot(req).await;
assert_eq!(
r#"when Write Snapshot(SnapshotSignature { last_log_id: Some(LogId { leader_id: LeaderId { term: 1, node_id: 0 }, index: 0 }), last_membership_log_id: None, snapshot_id: "ss2" }): offsets do not match 6:8"#,
res.unwrap_err().to_string()
);
}

Ok(())
}

Expand Down

0 comments on commit 1b1f046

Please sign in to comment.