From 3764e7f7844c7f77ccc7ed5792b7ded01d89fe4e Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Fri, 2 Feb 2024 09:35:26 +1100 Subject: [PATCH 1/4] A draft API for validation of replicated shares This is a simple stream adapter that performs inline validation of anything that is replicated. It will error at the end of the stream if it finds a mismatch. (Future work might include periodic checkpoints.) The reason I'm putting this up is that it is a great example of how our borrowed contexts are a giant pain in the posterior. I'm starting to think that we need to think about moving to something with Arc so we can avoid this lifetime business. It's a real drag. --- ipa-core/src/protocol/basics/mod.rs | 1 + ipa-core/src/protocol/basics/validate.rs | 218 +++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 ipa-core/src/protocol/basics/validate.rs diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index 92ebdbfc2..e61f711d7 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -7,6 +7,7 @@ mod reshare; mod reveal; mod share_known_value; pub mod sum_of_product; +pub mod validate; pub use check_zero::check_zero; pub use if_else::if_else; diff --git a/ipa-core/src/protocol/basics/validate.rs b/ipa-core/src/protocol/basics/validate.rs new file mode 100644 index 000000000..e21810fa8 --- /dev/null +++ b/ipa-core/src/protocol/basics/validate.rs @@ -0,0 +1,218 @@ +use std::{ + convert::Infallible, + marker::PhantomData, + pin::Pin, + task::{Context as TaskContext, Poll}, +}; + +use futures::{ + future::try_join, + stream::{Fuse, Stream, StreamExt}, + Future, FutureExt, +}; +use generic_array::GenericArray; +use pin_project::pin_project; +use sha2::{ + digest::{typenum::Unsigned, FixedOutput, OutputSizeUser}, + Digest, Sha256, +}; + +use crate::{ + error::Error, + ff::Serializable, + helpers::{Direction, Message}, + protocol::{context::Context, RecordId}, + secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, +}; + +type HashFunction = Sha256; +type HashSize = ::OutputSize; +type HashOutputArray = [u8; ::USIZE]; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct HashValue(GenericArray); + +impl Serializable for HashValue { + type Size = HashSize; + type DeserializationError = Infallible; + + fn serialize(&self, buf: &mut GenericArray) { + buf.copy_from_slice(self.0.as_slice()) + } + + fn deserialize(buf: &GenericArray) -> Result { + Ok(Self(buf.to_owned())) + } +} + +impl Message for HashValue {} + +struct ReplicatedValidatorFinalization { + f: Pin>>>, + ctx: C, +} + +impl ReplicatedValidatorFinalization { + fn new(active: ReplicatedValidatorActive) -> Self { + let ReplicatedValidatorActive { + ctx, + left_hash, + right_hash, + } = active; + // Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do. + let left_hash = HashValue(GenericArray::from(::from( + left_hash.finalize_fixed(), + ))); + let right_hash = HashValue(GenericArray::from(::from( + right_hash.finalize_fixed(), + ))); + let left_peer = ctx.role().peer(Direction::Left); + let right_peer = ctx.role().peer(Direction::Left); + let ctx_ref = &ctx; + + let f = Box::pin(async move { + try_join( + ctx_ref + .send_channel(left_peer) + .send(RecordId::FIRST, left_hash.clone()), + ctx_ref + .send_channel(right_peer) + .send(RecordId::FIRST, right_hash.clone()), + ) + .await?; + let (left_recvd, right_recvd) = try_join( + ctx_ref.recv_channel(left_peer).receive(RecordId::FIRST), + ctx_ref.recv_channel(right_peer).receive(RecordId::FIRST), + ) + .await?; + if left_hash == left_recvd && right_hash == right_recvd { + Ok(()) + } else { + Err(Error::Internal) // TODO add a code + } + }); + Self { f, ctx } + } + + fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll> { + self.f.poll_unpin(cx) + } +} + +struct ReplicatedValidatorActive { + ctx: C, + left_hash: Sha256, + right_hash: Sha256, +} + +impl ReplicatedValidatorActive { + fn new(ctx: C) -> Self { + Self { + ctx, + left_hash: HashFunction::new(), + right_hash: HashFunction::new(), + } + } + + fn update(&mut self, s: &S) + where + S: ReplicatedSecretSharing, + V: SharedValue, + { + let mut buf = GenericArray::default(); // ::::Size> + s.left().serialize(&mut buf); + self.left_hash.update(buf.as_slice()); + s.right().serialize(&mut buf); + self.right_hash.update(buf.as_slice()); + } + + fn finalize(self) -> ReplicatedValidatorFinalization { + ReplicatedValidatorFinalization::new(self) + } +} + +enum ReplicatedValidatorState { + /// While the validator is waiting, it holds a context reference. + Pending(Option>), + /// After the validator has taken all of its inputs, it holds a future. + Finalizing(ReplicatedValidatorFinalization), +} + +impl ReplicatedValidatorState { + /// # Panics + /// This panics if it is called after `finalize()`. + fn update(&mut self, s: &S) + where + S: ReplicatedSecretSharing, + V: SharedValue, + { + if let Self::Pending(Some(a)) = self { + a.update(s); + } else { + panic!(); + } + } + + fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll> { + match self { + Self::Pending(ref mut active) => { + let mut f = active.take().unwrap().finalize(); + let res = f.poll(cx); + *self = ReplicatedValidatorState::Finalizing(f); + res + } + Self::Finalizing(f) => f.poll(cx), + } + } +} + +#[pin_project] +struct ReplicatedValidator { + #[pin] + input: Fuse, + state: ReplicatedValidatorState, + _marker: PhantomData<(S, V)>, +} + +impl ReplicatedValidator { + pub fn new(ctx: C, s: T) -> Self { + Self { + input: s.fuse(), + state: ReplicatedValidatorState::Pending(Some(ReplicatedValidatorActive::new(ctx))), + _marker: PhantomData, + } + } +} + +impl Stream for ReplicatedValidator +where + C: Context + 'static, + T: Stream>, + S: ReplicatedSecretSharing, + V: SharedValue, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + let this = self.project(); + match this.input.poll_next(cx) { + Poll::Ready(Some(v)) => match v { + Ok(v) => { + this.state.update(&v); + Poll::Ready(Some(Ok(v))) + } + Err(e) => Poll::Ready(Some(Err(e))), + }, + Poll::Ready(None) => match this.state.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(None), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + }, + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.input.size_hint() + } +} From f048b1d8ec283c21d798e004753a839325ce5127 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 5 Feb 2024 12:35:32 +1100 Subject: [PATCH 2/4] Around we go again --- ipa-core/src/protocol/basics/validate.rs | 120 ++++++++++++++++++----- ipa-macros/src/derive_step/mod.rs | 14 ++- ipa-macros/src/parser.rs | 16 ++- 3 files changed, 111 insertions(+), 39 deletions(-) diff --git a/ipa-core/src/protocol/basics/validate.rs b/ipa-core/src/protocol/basics/validate.rs index e21810fa8..a146b87fa 100644 --- a/ipa-core/src/protocol/basics/validate.rs +++ b/ipa-core/src/protocol/basics/validate.rs @@ -23,6 +23,7 @@ use crate::{ helpers::{Direction, Message}, protocol::{context::Context, RecordId}, secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, + seq_join::assert_send, }; type HashFunction = Sha256; @@ -37,7 +38,7 @@ impl Serializable for HashValue { type DeserializationError = Infallible; fn serialize(&self, buf: &mut GenericArray) { - buf.copy_from_slice(self.0.as_slice()) + buf.copy_from_slice(self.0.as_slice()); } fn deserialize(buf: &GenericArray) -> Result { @@ -47,42 +48,42 @@ impl Serializable for HashValue { impl Message for HashValue {} -struct ReplicatedValidatorFinalization { +impl From for HashValue { + fn from(value: HashFunction) -> Self { + // Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do. + HashValue(GenericArray::from(::from( + value.finalize_fixed(), + ))) + } +} + +struct ReplicatedValidatorFinalization { f: Pin>>>, - ctx: C, } -impl ReplicatedValidatorFinalization { - fn new(active: ReplicatedValidatorActive) -> Self { +impl ReplicatedValidatorFinalization { + fn new(active: ReplicatedValidatorActive) -> Self { let ReplicatedValidatorActive { ctx, left_hash, right_hash, } = active; - // Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do. - let left_hash = HashValue(GenericArray::from(::from( - left_hash.finalize_fixed(), - ))); - let right_hash = HashValue(GenericArray::from(::from( - right_hash.finalize_fixed(), - ))); + let left_hash = HashValue::from(left_hash); + let right_hash = HashValue::from(right_hash); let left_peer = ctx.role().peer(Direction::Left); - let right_peer = ctx.role().peer(Direction::Left); - let ctx_ref = &ctx; + let right_peer = ctx.role().peer(Direction::Right); - let f = Box::pin(async move { + let f = Box::pin(assert_send(async move { try_join( - ctx_ref - .send_channel(left_peer) + ctx.send_channel(left_peer) .send(RecordId::FIRST, left_hash.clone()), - ctx_ref - .send_channel(right_peer) + ctx.send_channel(right_peer) .send(RecordId::FIRST, right_hash.clone()), ) .await?; let (left_recvd, right_recvd) = try_join( - ctx_ref.recv_channel(left_peer).receive(RecordId::FIRST), - ctx_ref.recv_channel(right_peer).receive(RecordId::FIRST), + ctx.recv_channel(left_peer).receive(RecordId::FIRST), + ctx.recv_channel(right_peer).receive(RecordId::FIRST), ) .await?; if left_hash == left_recvd && right_hash == right_recvd { @@ -90,8 +91,8 @@ impl ReplicatedValidatorFinalization { } else { Err(Error::Internal) // TODO add a code } - }); - Self { f, ctx } + })); + Self { f } } fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll> { @@ -126,16 +127,16 @@ impl ReplicatedValidatorActive { self.right_hash.update(buf.as_slice()); } - fn finalize(self) -> ReplicatedValidatorFinalization { + fn finalize(self) -> ReplicatedValidatorFinalization { ReplicatedValidatorFinalization::new(self) } } enum ReplicatedValidatorState { /// While the validator is waiting, it holds a context reference. - Pending(Option>), + Pending(Option>>), /// After the validator has taken all of its inputs, it holds a future. - Finalizing(ReplicatedValidatorFinalization), + Finalizing(ReplicatedValidatorFinalization), } impl ReplicatedValidatorState { @@ -178,7 +179,9 @@ impl ReplicatedValidator { pub fn new(ctx: C, s: T) -> Self { Self { input: s.fuse(), - state: ReplicatedValidatorState::Pending(Some(ReplicatedValidatorActive::new(ctx))), + state: ReplicatedValidatorState::Pending(Some(Box::new( + ReplicatedValidatorActive::new(ctx), + ))), _marker: PhantomData, } } @@ -216,3 +219,66 @@ where self.input.size_hint() } } + +#[cfg(test)] +mod test { + use std::iter::repeat_with; + + use futures::stream::{iter as stream_iter, Stream, StreamExt, TryStreamExt}; + + use crate::{ + error::Error, + ff::Fp31, + helpers::Direction, + protocol::{basics::validate::ReplicatedValidator, context::Context, RecordId}, + rand::{thread_rng, Rng}, + secret_sharing::{ + replicated::{ + semi_honest::AdditiveShare as SemiHonestReplicated, ReplicatedSecretSharing, + }, + SharedValue, + }, + test_fixture::{Reconstruct, Runner, TestWorld}, + }; + + fn assert_stream>, T>(s: S) -> S { + s + } + + /// Successfully validate some shares. + #[tokio::test] + pub async fn simple() { + let mut rng = thread_rng(); + let world = TestWorld::default(); + + let input = repeat_with(|| rng.gen::()) + .take(10) + .collect::>(); + let result = world + .semi_honest(input.into_iter(), |ctx, shares| async move { + let ctx = ctx.set_total_records(shares.len()); + let s = stream_iter(shares).map(|x| Ok(x)); + let vs = ReplicatedValidator::new(ctx.narrow("validate"), s); + let sum = assert_stream(vs) + .try_fold(Fp31::ZERO, |sum, value| async move { + Ok(sum + value.left() - value.right()) + }) + .await?; + // This value should sum to zero now, so replicate the value. + // (We don't care here that this reveals our share to other helpers, it's just a test.) + ctx.send_channel(ctx.role().peer(Direction::Right)) + .send(RecordId::FIRST, sum) + .await?; + let left = ctx + .recv_channel(ctx.role().peer(Direction::Left)) + .receive(RecordId::FIRST) + .await?; + Ok(SemiHonestReplicated::new(left, sum)) + }) + .await + .map(Result::<_, Error>::unwrap) + .reconstruct(); + + assert_eq!(Fp31::ZERO, result); + } +} diff --git a/ipa-macros/src/derive_step/mod.rs b/ipa-macros/src/derive_step/mod.rs index 9093916e5..f03060422 100644 --- a/ipa-macros/src/derive_step/mod.rs +++ b/ipa-macros/src/derive_step/mod.rs @@ -44,7 +44,7 @@ use syn::{parse_macro_input, DeriveInput}; use crate::{ parser::{group_by_modules, ipa_state_transition_map, StepMetaData}, - tree::Node, + tree::{self, Node}, }; const MAX_DYNAMIC_STEPS: usize = 1024; @@ -115,7 +115,7 @@ fn impl_as_ref(ident: &syn::Ident, data: &syn::DataEnum) -> Result Result>(); let steps_array_ident = format_ident!("{}_DYNAMIC_STEP", ident_upper_case); const_arrays.extend(quote!( @@ -272,9 +272,8 @@ fn get_meta_data_for( 1 => { Ok(target_steps[0] .iter() - .map(|s| - // we want to retain the references to the parents, so we use `upgrade()` - s.upgrade()) + // we want to retain the references to the parents, so we use `upgrade()` + .map(tree::Node::upgrade) .collect::>()) } _ => Err(syn::Error::new_spanned( @@ -314,8 +313,7 @@ fn get_dynamic_step_count(variant: &syn::Variant) -> Result { dynamic_attr, format!( "ipa_macros::step \"dynamic\" attribute expects a number of steps \ - (<= {}) in parentheses: #[dynamic(...)].", - MAX_DYNAMIC_STEPS, + (<= {MAX_DYNAMIC_STEPS}) in parentheses: #[dynamic(...)].", ), )), } diff --git a/ipa-macros/src/parser.rs b/ipa-macros/src/parser.rs index 25dd32bbe..691341c64 100644 --- a/ipa-macros/src/parser.rs +++ b/ipa-macros/src/parser.rs @@ -71,7 +71,10 @@ pub(crate) fn read_steps_file(file_path: &str) -> Vec { let mut file = std::fs::File::open(path).expect("Could not open the steps file"); let mut contents = String::new(); file.read_to_string(&mut contents).unwrap(); - contents.lines().map(|s| s.to_owned()).collect::>() + contents + .lines() + .map(std::borrow::ToOwned::to_owned) + .collect::>() } /// Constructs a tree structure with nodes that contain the `Step` instances. @@ -109,10 +112,15 @@ pub(crate) fn construct_tree(steps: Vec) -> Node { /// Split a single substep full path into the module path and the step's name. /// /// # Example +/// ```ignore /// input = "ipa::protocol::modulus_conversion::convert_shares::Step::xor1" /// output = ("ipa::protocol::modulus_conversion::convert_shares::Step", "xor1") +/// ``` pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) { - let mod_parts = input.split("::").map(|s| s.to_owned()).collect::>(); + let mod_parts = input + .split("::") + .map(std::borrow::ToOwned::to_owned) + .collect::>(); let (substep_name, path) = mod_parts.split_last().unwrap(); (path.join("::"), substep_name.to_owned()) } @@ -123,8 +131,8 @@ pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) { /// # Example /// Let say we have the following steps: /// -/// - StepA::A1 -/// - StepC::C1/StepD::D1/StepA::A2 +/// - `StepA::A1` +/// - `StepC::C1/StepD::D1/StepA::A2` /// /// If we generate code for each node while traversing, we will end up with the following: /// From 124ea04aeffe1cfd3305ce3169f06784028ddfd3 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 5 Feb 2024 12:46:11 +1100 Subject: [PATCH 3/4] Around we go again --- ipa-core/src/protocol/basics/validate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/protocol/basics/validate.rs b/ipa-core/src/protocol/basics/validate.rs index a146b87fa..1341b46ff 100644 --- a/ipa-core/src/protocol/basics/validate.rs +++ b/ipa-core/src/protocol/basics/validate.rs @@ -58,7 +58,7 @@ impl From for HashValue { } struct ReplicatedValidatorFinalization { - f: Pin>>>, + f: Pin> + Send)>>, } impl ReplicatedValidatorFinalization { From 211a982169c992a1d98402ddbe0d9e0ac5575c43 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 7 Feb 2024 16:23:05 +1100 Subject: [PATCH 4/4] Working! --- ipa-core/src/error.rs | 2 + ipa-core/src/protocol/basics/validate.rs | 110 +++++++++++++++++------ 2 files changed, 87 insertions(+), 25 deletions(-) diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs index 0c2705f1f..6b99bae13 100644 --- a/ipa-core/src/error.rs +++ b/ipa-core/src/error.rs @@ -36,6 +36,8 @@ pub enum Error { MaliciousSecurityCheckFailed, #[error("malicious reveal failed")] MaliciousRevealFailed, + #[error("share values were inconsistent between helpers")] + InconsistentShares, #[error("problem during IO: {0}")] Io(#[from] std::io::Error), // TODO remove if this https://github.com/awslabs/shuttle/pull/109 gets approved diff --git a/ipa-core/src/protocol/basics/validate.rs b/ipa-core/src/protocol/basics/validate.rs index 1341b46ff..80a42abc5 100644 --- a/ipa-core/src/protocol/basics/validate.rs +++ b/ipa-core/src/protocol/basics/validate.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] // Not wired in yet. + use std::{ convert::Infallible, marker::PhantomData, @@ -57,16 +59,18 @@ impl From for HashValue { } } -struct ReplicatedValidatorFinalization { - f: Pin> + Send)>>, +/// The finalizing state for the validator. +struct ReplicatedValidatorFinalization<'a> { + f: Pin> + Send + 'a)>>, } -impl ReplicatedValidatorFinalization { - fn new(active: ReplicatedValidatorActive) -> Self { +impl<'a> ReplicatedValidatorFinalization<'a> { + fn new(active: ReplicatedValidatorActive<'a, C>) -> Self { let ReplicatedValidatorActive { ctx, left_hash, right_hash, + .. } = active; let left_hash = HashValue::from(left_hash); let right_hash = HashValue::from(right_hash); @@ -89,7 +93,7 @@ impl ReplicatedValidatorFinalization { if left_hash == left_recvd && right_hash == right_recvd { Ok(()) } else { - Err(Error::Internal) // TODO add a code + Err(Error::InconsistentShares) } })); Self { f } @@ -100,18 +104,21 @@ impl ReplicatedValidatorFinalization { } } -struct ReplicatedValidatorActive { +/// The active state for the validator. +struct ReplicatedValidatorActive<'a, C: 'a> { ctx: C, left_hash: Sha256, right_hash: Sha256, + _marker: PhantomData<&'a ()>, } -impl ReplicatedValidatorActive { +impl<'a, C: Context + 'a> ReplicatedValidatorActive<'a, C> { fn new(ctx: C) -> Self { Self { ctx, left_hash: HashFunction::new(), right_hash: HashFunction::new(), + _marker: PhantomData, } } @@ -127,19 +134,19 @@ impl ReplicatedValidatorActive { self.right_hash.update(buf.as_slice()); } - fn finalize(self) -> ReplicatedValidatorFinalization { + fn finalize(self) -> ReplicatedValidatorFinalization<'a> { ReplicatedValidatorFinalization::new(self) } } -enum ReplicatedValidatorState { +enum ReplicatedValidatorState<'a, C: 'a> { /// While the validator is waiting, it holds a context reference. - Pending(Option>>), + Pending(Option>>), /// After the validator has taken all of its inputs, it holds a future. - Finalizing(ReplicatedValidatorFinalization), + Finalizing(ReplicatedValidatorFinalization<'a>), } -impl ReplicatedValidatorState { +impl<'a, C: Context + 'a> ReplicatedValidatorState<'a, C> { /// # Panics /// This panics if it is called after `finalize()`. fn update(&mut self, s: &S) @@ -167,29 +174,37 @@ impl ReplicatedValidatorState { } } +/// A `ReplicatedValidator` takes a stream of replicated shares of anything +/// and produces a stream of the same values, without modifying them. +/// The only thing it does is check that the values are consistent across +/// all three helpers using the provided context. +/// To do this, it sends a single message. +/// +/// If validation passes, the stream is completely transparent. +/// If validation fails, the stream will error before it closes. #[pin_project] -struct ReplicatedValidator { +struct ReplicatedValidator<'a, C: 'a, T: Stream, S, V> { #[pin] input: Fuse, - state: ReplicatedValidatorState, + state: ReplicatedValidatorState<'a, C>, _marker: PhantomData<(S, V)>, } -impl ReplicatedValidator { - pub fn new(ctx: C, s: T) -> Self { +impl<'a, C: Context + 'a, T: Stream, S, V> ReplicatedValidator<'a, C, T, S, V> { + pub fn new(ctx: &C, s: T) -> Self { Self { input: s.fuse(), state: ReplicatedValidatorState::Pending(Some(Box::new( - ReplicatedValidatorActive::new(ctx), + ReplicatedValidatorActive::new(ctx.set_total_records(1)), ))), _marker: PhantomData, } } } -impl Stream for ReplicatedValidator +impl<'a, C, T, S, V> Stream for ReplicatedValidator<'a, C, T, S, V> where - C: Context + 'static, + C: Context + 'a, T: Stream>, S: ReplicatedSecretSharing, V: SharedValue, @@ -220,7 +235,7 @@ where } } -#[cfg(test)] +#[cfg(all(test, unit_test))] mod test { use std::iter::repeat_with; @@ -228,8 +243,8 @@ mod test { use crate::{ error::Error, - ff::Fp31, - helpers::Direction, + ff::{Field, Fp31}, + helpers::{Direction, Role}, protocol::{basics::validate::ReplicatedValidator, context::Context, RecordId}, rand::{thread_rng, Rng}, secret_sharing::{ @@ -256,14 +271,14 @@ mod test { .collect::>(); let result = world .semi_honest(input.into_iter(), |ctx, shares| async move { - let ctx = ctx.set_total_records(shares.len()); - let s = stream_iter(shares).map(|x| Ok(x)); - let vs = ReplicatedValidator::new(ctx.narrow("validate"), s); + let s = stream_iter(shares).map(Ok); + let vs = ReplicatedValidator::new(&ctx.narrow("validate"), s); let sum = assert_stream(vs) .try_fold(Fp31::ZERO, |sum, value| async move { Ok(sum + value.left() - value.right()) }) .await?; + let ctx = ctx.set_total_records(1); // This value should sum to zero now, so replicate the value. // (We don't care here that this reveals our share to other helpers, it's just a test.) ctx.send_channel(ctx.role().peer(Direction::Right)) @@ -281,4 +296,49 @@ mod test { assert_eq!(Fp31::ZERO, result); } + + #[tokio::test] + pub async fn inconsistent() { + let mut rng = thread_rng(); + let world = TestWorld::default(); + + let damage = |role| { + let mut tweak = role == Role::H3; + move |v: SemiHonestReplicated| -> SemiHonestReplicated { + if tweak { + tweak = false; + SemiHonestReplicated::new(v.left(), v.right() + Fp31::ONE) + } else { + v + } + } + }; + + let input = repeat_with(|| rng.gen::()) + .take(10) + .collect::>(); + let result = world + .semi_honest(input.into_iter(), |ctx, shares| async move { + let s = stream_iter(shares).map(damage(ctx.role())).map(Ok); + let vs = ReplicatedValidator::new(&ctx.narrow("validate"), s); + let sum = assert_stream(vs) + .try_fold(Fp31::ZERO, |sum, value| async move { + Ok(sum + value.left() - value.right()) + }) + .await?; + Ok(sum) // This will be not be reached by 2/3 helpers. + }) + .await; + + // With just one error having been introduced, two of three helpers will error out. + assert!(matches!( + result[0].as_ref().unwrap_err(), + Error::InconsistentShares + )); + assert!(result[1].is_ok()); + assert!(matches!( + result[2].as_ref().unwrap_err(), + Error::InconsistentShares + )); + } }