Add a session::tokio module with a convenience function for executing a session #91

Open · wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion examples/Cargo.toml
@@ -21,5 +21,5 @@ displaydoc = "0.2"
tokio = { version = "1", features = ["rt", "sync", "time", "macros"] }
rand = "0.8"
digest = "0.10"
-manul = { path = "../manul", features = ["dev"] }
+manul = { path = "../manul", features = ["dev", "tokio"] }
test-log = { version = "0.2", features = ["trace", "color"] }
143 changes: 7 additions & 136 deletions examples/tests/async_runner.rs
@@ -5,7 +5,10 @@ use alloc::collections::{BTreeMap, BTreeSet};
use manul::{
dev::{BinaryFormat, TestSessionParams, TestSigner},
protocol::Protocol,
-session::{CanFinalize, LocalError, Message, RoundOutcome, Session, SessionId, SessionParameters, SessionReport},
+session::{
+    tokio::{run_session, MessageIn, MessageOut},
+    Session, SessionId, SessionParameters, SessionReport,
+},
signature::Keypair,
};
use manul_example::simple::{SimpleProtocol, SimpleProtocolEntryPoint};
@@ -15,139 +18,6 @@ use tokio::{
sync::mpsc,
time::{sleep, Duration},
};
use tracing::{debug, trace};

struct MessageOut<SP: SessionParameters> {
from: SP::Verifier,
to: SP::Verifier,
message: Message<SP::Verifier>,
}

struct MessageIn<SP: SessionParameters> {
from: SP::Verifier,
message: Message<SP::Verifier>,
}

/// Runs a session. Simulates what each participating party would run as the protocol progresses.
async fn run_session<P, SP>(
tx: mpsc::Sender<MessageOut<SP>>,
rx: mpsc::Receiver<MessageIn<SP>>,
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
let rng = &mut OsRng;

let mut rx = rx;

let mut session = session;
// Some rounds can finalize early and put off sending messages to the next round. Such messages
// will be stored here and applied after the messages for this round are sent.
let mut cached_messages = Vec::new();

let key = session.verifier();

// Each iteration of the loop progresses the session as follows:
// - Send out messages as dictated by the session "destinations".
// - Apply any cached messages.
// - Enter a nested loop:
// - Try to finalize the session; if we're done, exit the inner loop.
// - Wait until we get an incoming message.
// - Process the message we received and continue the loop.
// - When all messages have been sent and received as specified by the protocol, finalize the
// round.
// - If the protocol outcome is a new round, go to the top of the loop and start over with a
// new session.
loop {
debug!("{key:?}: *** starting round {:?} ***", session.round_id());

// This is kept in the main task since it's mutable,
// and we don't want to bother with synchronization.
let mut accum = session.make_accumulator();

// Note: generating/sending messages and verifying newly received messages
// can be done in parallel, with the results being assembled into `accum`
// sequentially in the host task.

let destinations = session.message_destinations();
for destination in destinations.iter() {
// In production usage, this will happen in a spawned task
// (since it can take some time to create a message),
// and the artifact will be sent back to the host task
// to be added to the accumulator.
let (message, artifact) = session.make_message(rng, destination)?;
debug!("{key:?}: Sending a message to {destination:?}",);
tx.send(MessageOut {
from: key.clone(),
to: destination.clone(),
message,
})
.await
.unwrap();

// This would happen in a host task
session.add_artifact(&mut accum, artifact)?;
}

for preprocessed in cached_messages {
// In production usage, this would happen in a spawned task and relayed back to the main task.
debug!("{key:?}: Applying a cached message");
let processed = session.process_message(preprocessed);

// This would happen in a host task.
session.add_processed_message(&mut accum, processed)?;
}

loop {
match session.can_finalize(&accum) {
CanFinalize::Yes => break,
CanFinalize::NotYet => {}
// Due to already registered invalid messages from nodes,
// even if the remaining nodes send correct messages, it won't be enough.
// Terminating.
CanFinalize::Never => {
tracing::warn!("{key:?}: This session cannot ever be finalized. Terminating.");
return session.terminate_due_to_errors(accum);
}
}

debug!("{key:?}: Waiting for a message");
let incoming = rx.recv().await.unwrap();

// Perform quick checks before proceeding with the verification.
match session
.preprocess_message(&mut accum, &incoming.from, incoming.message)?
.ok()
{
Some(preprocessed) => {
// In production usage, this would happen in a separate task.
debug!("{key:?}: Applying a message from {:?}", incoming.from);
let processed = session.process_message(preprocessed);
// In production usage, this would be a host task.
session.add_processed_message(&mut accum, processed)?;
}
None => {
trace!("{key:?} Pre-processing complete. Current state: {accum:?}")
}
}
}

debug!("{key:?}: Finalizing the round");

match session.finalize_round(rng, accum)? {
RoundOutcome::Finished(report) => break Ok(report),
RoundOutcome::AnotherRound {
session: new_session,
cached_messages: new_cached_messages,
} => {
session = new_session;
cached_messages = new_cached_messages;
}
}
}
}

async fn message_dispatcher<SP>(
txs: BTreeMap<SP::Verifier, mpsc::Sender<MessageIn<SP>>>,
@@ -217,8 +87,9 @@ where
let handles = rxs
.into_iter()
.zip(sessions.into_iter())
-.map(|(rx, session)| {
-    let node_task = run_session(dispatcher_tx.clone(), rx, session);
+.map(|(mut rx, session)| {
+    let tx = dispatcher_tx.clone();
+    let node_task = async move { run_session(&mut OsRng, &tx, &mut rx, session).await };
tokio::spawn(node_task)
})
.collect::<Vec<_>>();
2 changes: 2 additions & 0 deletions manul/Cargo.toml
@@ -27,6 +27,7 @@ rand = { version = "0.8", default-features = false, optional = true }
serde-persistent-deserializer = { version = "0.3", optional = true }
postcard = { version = "1", default-features = false, features = ["alloc"], optional = true }
serde_json = { version = "1", default-features = false, features = ["alloc"], optional = true }
tokio = { version = "1", default-features = false, features = ["sync"], optional = true }

[dev-dependencies]
impls = "1"
@@ -43,6 +44,7 @@ tracing = { version = "0.1", default-features = false, features = ["std"] }

[features]
dev = ["rand", "postcard", "serde_json", "tracing/std", "serde-persistent-deserializer"]
tokio = ["dep:tokio"]

[package.metadata.docs.rs]
all-features = true
3 changes: 3 additions & 0 deletions manul/src/session.rs
@@ -15,6 +15,9 @@ mod session
mod transcript;
mod wire_format;

#[cfg(feature = "tokio")]
pub mod tokio;

pub use crate::protocol::{LocalError, RemoteError};
pub use evidence::{Evidence, EvidenceError};
pub use message::{Message, VerifiedMessage};
169 changes: 169 additions & 0 deletions manul/src/session/tokio.rs
@@ -0,0 +1,169 @@
//! High-level API for executing sessions in `tokio` tasks.

use alloc::{format, vec::Vec};

use rand_core::CryptoRngCore;
use tokio::sync::mpsc;
use tracing::{debug, trace};

use super::{
message::Message,
session::{CanFinalize, RoundOutcome, Session, SessionParameters},
transcript::SessionReport,
LocalError,
};
use crate::protocol::Protocol;

/// The outgoing message from a local session.
#[derive(Debug)]
pub struct MessageOut<SP: SessionParameters> {
Contributor:
The MessageOut and MessageIn types do not seem to be tokio-specific. Are they generally useful, you reckon, or will users want to implement their own versions? If the latter, why would they want that?

Member (author):
These are just named tuples with the returned data. The user is free to re-wrap them differently.

/// The ID of the session that created the message.
Contributor:
Is this accurate? Given there is a public SessionId type, it is a bit confusing to read this and then read SP::Verifier, which boils down to a PartyId. Is it "The ID of the party that created the message in this session", perhaps?

///
/// Useful when there are several sessions running on a node, pushing messages into the same channel.
pub from: SP::Verifier,
/// The ID of the session the message is intended for.
pub to: SP::Verifier,
/// The message to be sent.
///
/// Note that the caller is responsible for encrypting the message and attaching authentication info.
pub message: Message<SP::Verifier>,
}
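
The author's reply above treats these as plain data carriers, and the field docs make the caller responsible for encryption and authentication. Below is a minimal sketch of re-wrapping MessageOut for a custom transport; WireEnvelope and the seal closure are illustrative placeholders, not part of this PR:

use manul::session::{tokio::MessageOut, Message, SessionParameters};

// Hypothetical wire format; any encoding works, since `MessageOut`
// is just a named tuple of the session's outputs.
struct WireEnvelope<SP: SessionParameters> {
    from: SP::Verifier,
    to: SP::Verifier,
    // Encrypted, authenticated bytes produced by the caller's transport layer.
    payload: Vec<u8>,
}

fn to_wire<SP: SessionParameters>(
    out: MessageOut<SP>,
    // `seal` stands in for the caller's encrypt-and-authenticate step.
    seal: impl Fn(&Message<SP::Verifier>) -> Vec<u8>,
) -> WireEnvelope<SP> {
    WireEnvelope {
        payload: seal(&out.message),
        from: out.from,
        to: out.to,
    }
}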

/// The incoming message from a remote session.
#[derive(Debug)]
pub struct MessageIn<SP: SessionParameters> {
/// The ID of the session the message originated from.
Contributor:
Same as above.

///
/// It is assumed that the message's authentication info has been checked at this point.
pub from: SP::Verifier,
/// The incoming message.
pub message: Message<SP::Verifier>,
}
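
Since the `from` field's docs assume authentication has already been checked, feeding a session is just a channel send. A minimal sketch, assuming the transport has already decoded and verified the envelope:

use manul::session::{tokio::MessageIn, Message, SessionParameters};
use tokio::sync::mpsc;

// Forward a decoded message into a running `run_session` task.
// Call this only after the transport has verified the sender's
// authentication info, as the field docs above assume.
async fn forward_incoming<SP: SessionParameters>(
    tx: &mpsc::Sender<MessageIn<SP>>,
    from: SP::Verifier,
    message: Message<SP::Verifier>,
) -> Result<(), &'static str> {
    tx.send(MessageIn { from, message })
        .await
        .map_err(|_| "the session task has shut down")
}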

/// Executes the session, waiting for messages on the `rx` channel
/// and pushing outgoing messages into the `tx` channel.
pub async fn run_session<P, SP>(
rng: &mut impl CryptoRngCore,
tx: &mpsc::Sender<MessageOut<SP>>,
rx: &mut mpsc::Receiver<MessageIn<SP>>,
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
P: Protocol<SP::Verifier>,
SP: SessionParameters,
{
let mut session = session;
// Some rounds can finalize early and put off sending messages to the next round. Such messages
// will be stored here and applied after the messages for this round are sent.
let mut cached_messages = Vec::new();

let key = session.verifier();
Contributor:
Suggested change:
-let key = session.verifier();
+let verifier = session.verifier();
…or something more specific to its role? Maybe just from?


// Each iteration of the loop progresses the session as follows:
// - Send out messages as dictated by the session "destinations".
// - Apply any cached messages.
// - Enter a nested loop:
// - Try to finalize the session; if we're done, exit the inner loop.
// - Wait until we get an incoming message.
// - Process the message we received and continue the loop.
// - When all messages have been sent and received as specified by the protocol, finalize the
// round.
// - If the protocol outcome is a new round, go to the top of the loop and start over with a
// new session.
loop {
debug!("{key:?}: *** starting round {:?} ***", session.round_id());

// This is kept in the main task since it's mutable,
// and we don't want to bother with synchronization.
let mut accum = session.make_accumulator();

// Note: generating/sending messages and verifying newly received messages
// can be done in parallel, with the results being assembled into `accum`
// sequentially in the host task.

let destinations = session.message_destinations();
for destination in destinations.iter() {
// In production usage, this will happen in a spawned task
// (since it can take some time to create a message),
// and the artifact will be sent back to the host task
// to be added to the accumulator.
let (message, artifact) = session.make_message(rng, destination)?;
debug!("{key:?}: Sending a message to {destination:?}",);
tx.send(MessageOut {
from: session.verifier().clone(),
to: destination.clone(),
message,
})
.await
.map_err(|err| {
LocalError::new(format!(
"Failed to send a message from {:?} to {:?}: {err}",
session.verifier(),
Contributor:
Suggested change:
-session.verifier(),
+key,
destination
))
})?;

// This would happen in a host task
session.add_artifact(&mut accum, artifact)?;
}

for preprocessed in cached_messages {
// In production usage, this would happen in a spawned task and relayed back to the main task.
debug!("{key:?}: Applying a cached message");
let processed = session.process_message(preprocessed);

// This would happen in a host task.
session.add_processed_message(&mut accum, processed)?;
}

loop {
match session.can_finalize(&accum) {
CanFinalize::Yes => break,
CanFinalize::NotYet => {}
// Due to already registered invalid messages from nodes,
// even if the remaining nodes send correct messages, it won't be enough.
// Terminating.
CanFinalize::Never => {
tracing::warn!("{key:?}: This session cannot ever be finalized. Terminating.");
return session.terminate_due_to_errors(accum);
}
}

debug!("{key:?}: Waiting for a message");
let message_in = rx
.recv()
.await
.ok_or_else(|| LocalError::new("Failed to receive a message"))?;

// Perform quick checks before proceeding with the verification.
match session
.preprocess_message(&mut accum, &message_in.from, message_in.message)?
.ok()
{
Some(preprocessed) => {
// In production usage, this would happen in a separate task.
debug!("{key:?}: Applying a message from {:?}", message_in.from);
let processed = session.process_message(preprocessed);
// In production usage, this would be a host task.
session.add_processed_message(&mut accum, processed)?;
}
None => {
trace!("{key:?}: Pre-processing complete. Current state: {accum:?}")
}
}
}

debug!("{key:?}: Finalizing the round");

match session.finalize_round(rng, accum)? {
RoundOutcome::Finished(report) => break Ok(report),
RoundOutcome::AnotherRound {
session: new_session,
cached_messages: new_cached_messages,
} => {
session = new_session;
cached_messages = new_cached_messages;
}
}
}
}
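
To tie the pieces together: with the `tokio` feature enabled (as in the examples/Cargo.toml change above), a caller wires each session to a pair of channels and awaits the report. A minimal sketch, assuming rand_core's OsRng and leaving the protocol-specific Session construction to the caller:

use manul::{
    protocol::Protocol,
    session::{
        tokio::{run_session, MessageIn, MessageOut},
        LocalError, Session, SessionParameters, SessionReport,
    },
};
use rand_core::OsRng;
use tokio::sync::mpsc;

// Drive one party's session to completion over caller-provided channels;
// the transport (e.g. the dispatcher in examples/tests/async_runner.rs)
// routes MessageOut values to the right peers' MessageIn channels.
async fn run_node<P, SP>(
    session: Session<P, SP>,
    tx: mpsc::Sender<MessageOut<SP>>,
    mut rx: mpsc::Receiver<MessageIn<SP>>,
) -> Result<SessionReport<P, SP>, LocalError>
where
    P: Protocol<SP::Verifier>,
    SP: SessionParameters,
{
    run_session(&mut OsRng, &tx, &mut rx, session).await
}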