Skip to content

Commit

Permalink
polishing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Hannah Davis committed Oct 22, 2024
1 parent d927640 commit 5a46f32
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 77 deletions.
5 changes: 3 additions & 2 deletions benches/speed_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ use prio::dp::distributions::DiscreteGaussian;
use prio::idpf::test_utils::generate_zipf_distributed_batch;
#[cfg(feature = "experimental")]
use prio::vdaf::prio2::Prio2;
#[cfg(feature = "experimental")]
use prio::vdaf::vidpf::VidpfServerId;
use prio::{
benchmarked::*,
field::{random_vector, Field128 as F, FieldElement},
flp::gadgets::Mul,
vdaf::{prio3::Prio3, Aggregator, Client},
vidpf::VidpfServerId,
};
#[cfg(feature = "experimental")]
use prio::{
Expand Down Expand Up @@ -815,7 +816,7 @@ fn vidpf(c: &mut Criterion) {

b.iter(|| {
let _ = vidpf
.eval(&VidpfServerId::S0, &keys[0], &public, &input, NONCE)
.eval(VidpfServerId::S0, &keys[0], &public, &input, NONCE)
.unwrap();
});
});
Expand Down
71 changes: 21 additions & 50 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,10 @@ where

let leader_measurement_share =
self.vidpf
.eval_root(&VidpfServerId::S0, &vidpf_keys[0], &public_share, nonce)?;
.eval_root(VidpfServerId::S0, &vidpf_keys[0], &public_share, nonce)?;
let helper_measurement_share =
self.vidpf
.eval_root(&VidpfServerId::S1, &vidpf_keys[1], &public_share, nonce)?;
.eval_root(VidpfServerId::S1, &vidpf_keys[1], &public_share, nonce)?;

let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove(
leader_measurement_share.as_ref(),
Expand Down Expand Up @@ -537,7 +537,7 @@ where
);
let mut cache_tree = BinaryTree::<VidpfEvalCache<VidpfWeight<T::Field>>>::default();
let cache = VidpfEvalCache::<VidpfWeight<T::Field>>::init_from_key(
&id,
id,
&input_share.vidpf_key,
&self.vidpf.weight_parameter,
);
Expand All @@ -546,7 +546,7 @@ where
.expect("Should alwys be able to insert into empty tree at root");
for prefix in agg_param.level_and_prefixes.prefixes() {
let mut value_share = self.vidpf.eval_with_cache(
&id,
id,
&input_share.vidpf_key,
public_share,
prefix,
Expand All @@ -558,7 +558,7 @@ where
}
let root_share_opt = if agg_param.require_weight_check {
Some(self.vidpf.eval_root_with_cache(
&id,
id,
&input_share.vidpf_key,
public_share,
&mut cache_tree,
Expand Down Expand Up @@ -624,53 +624,24 @@ where
))?;
if inputs_iter.next().is_some() {
return Err(VdafError::Uncategorized(
"more than 2 prepare shares".to_string(),
"Received more than two prepare shares".to_string(),
));
};

match (leader_share, helper_share) {
(
MasticPrepareShare {
vidpf_proof: leader_vidpf_proof,
szk_query_share_opt: Some(leader_query_share),
},
MasticPrepareShare {
vidpf_proof: helper_vidpf_proof,
szk_query_share_opt: Some(helper_query_share),
},
) => {
if leader_vidpf_proof == helper_vidpf_proof {
Ok(Some(SzkQueryShare::merge_verifiers(
leader_query_share,
helper_query_share,
)))
} else {
Err(VdafError::Uncategorized(
"Vidpf proof verification failed".to_string(),
))
}
}
(
MasticPrepareShare {
vidpf_proof: leader_vidpf_proof,
szk_query_share_opt: None,
},
MasticPrepareShare {
vidpf_proof: helper_vidpf_proof,
szk_query_share_opt: None,
},
) => {
if leader_vidpf_proof == helper_vidpf_proof {
Ok(None)
} else {
Err(VdafError::Uncategorized(
"Vidpf proof verification failed".to_string(),
))
}
}
_ => Err(VdafError::Uncategorized(
"Prepare state and message disagree on whether Szk verification should occur"
.to_string(),
if leader_share.vidpf_proof != helper_share.vidpf_proof {
return Err(VdafError::Uncategorized(
"Vidpf proof verification failed".to_string(),
));
};
match (
leader_share.szk_query_share_opt,
helper_share.szk_query_share_opt,
) {
(Some(leader_query_share), Some(helper_query_share)) => Ok(Some(
SzkQueryShare::merge_verifiers(leader_query_share, helper_query_share),
)),
(None, None) => Ok(None),
(_, _) => Err(VdafError::Uncategorized(
"Only one of leader and helper query shares is present".to_string(),
)),
}
}
Expand Down
73 changes: 48 additions & 25 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
/// input's weight.
pub fn eval(
&self,
id: &VidpfServerId,
id: VidpfServerId,
key: &VidpfKey,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
Expand Down Expand Up @@ -230,7 +230,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
/// cache.
pub fn eval_with_cache(
&self,
id: &VidpfServerId,
id: VidpfServerId,
key: &VidpfKey,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
Expand Down Expand Up @@ -280,7 +280,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
/// state, and returns a new state and a share of the input's weight at that level.
fn eval_next(
&self,
id: &VidpfServerId,
id: VidpfServerId,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
level: usize,
Expand All @@ -306,7 +306,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
let zero = <W as IdpfValue>::zero(&self.weight_parameter);
let mut y = <W as IdpfValue>::conditional_select(&zero, &cw.weight, next_control_bit);
y += w_i;
y.conditional_negate(Choice::from(*id));
y.conditional_negate(Choice::from(id));

let pi_i = &state.proof;
let cs_i = public.cs.get(level).ok_or(VidpfError::IndexLevel)?;
Expand All @@ -328,7 +328,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {

pub(crate) fn eval_root_with_cache(
&self,
id: &VidpfServerId,
id: VidpfServerId,
key: &VidpfKey,
public_share: &VidpfPublicShare<W>,
cache_tree: &mut BinaryTree<VidpfEvalCache<W>>,
Expand Down Expand Up @@ -358,7 +358,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {

pub(crate) fn eval_root(
&self,
id: &VidpfServerId,
id: VidpfServerId,
key: &VidpfKey,
public_share: &VidpfPublicShare<W>,
nonce: &[u8; NONCE_SIZE],
Expand Down Expand Up @@ -615,10 +615,10 @@ pub struct VidpfEvalState {
}

impl VidpfEvalState {
fn init_from_key(id: &VidpfServerId, key: &VidpfKey) -> Self {
fn init_from_key(id: VidpfServerId, key: &VidpfKey) -> Self {
Self {
seed: key.0,
control_bit: Choice::from(*id),
control_bit: Choice::from(id),
proof: VidpfProof::default(),
}
}
Expand All @@ -635,7 +635,7 @@ pub struct VidpfEvalCache<W: VidpfValue> {

impl<W: VidpfValue> VidpfEvalCache<W> {
pub(crate) fn init_from_key(
id: &VidpfServerId,
id: VidpfServerId,
key: &VidpfKey,
length: &W::ValueParameter,
) -> Self {
Expand Down Expand Up @@ -839,15 +839,36 @@ mod tests {
mod vidpf {
use crate::{
bt::{BinaryTree, Path},
codec::{Encode, ParameterizedDecode},
idpf::IdpfValue,
vidpf::{
Vidpf, VidpfEvalCache, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare,
VidpfServerId,
},
};
use std::io::Cursor;

use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN};

#[test]
fn roundtrip_codec() {
let input = VidpfInput::from_bytes(&[0xFF]);
let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
let (_, public, _, _) = vidpf_gen_setup(&input, &weight);

let mut bytes = vec![];
public.encode(&mut bytes).unwrap();

assert_eq!(public.encoded_len().unwrap(), bytes.len());

let decoded = VidpfPublicShare::<TestWeight>::decode_with_param(
&(8, TEST_WEIGHT_LEN),
&mut Cursor::new(&bytes),
)
.unwrap();
assert_eq!(public, decoded);
}

fn vidpf_gen_setup(
input: &VidpfInput,
weight: &TestWeight,
Expand All @@ -869,10 +890,10 @@ mod tests {
let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(&input, &weight);

let value_share_0 = vidpf
.eval(&VidpfServerId::S0, &key_0, &public, &input, &nonce)
.eval(VidpfServerId::S0, &key_0, &public, &input, &nonce)
.unwrap();
let value_share_1 = vidpf
.eval(&VidpfServerId::S1, &key_1, &public, &input, &nonce)
.eval(VidpfServerId::S1, &key_1, &public, &input, &nonce)
.unwrap();

assert_eq!(
Expand All @@ -889,10 +910,10 @@ mod tests {
let bad_input = VidpfInput::from_bytes(&[0x00]);
let zero = TestWeight::zero(&TEST_WEIGHT_LEN);
let value_share_0 = vidpf
.eval(&VidpfServerId::S0, &key_0, &public, &bad_input, &nonce)
.eval(VidpfServerId::S0, &key_0, &public, &bad_input, &nonce)
.unwrap();
let value_share_1 = vidpf
.eval(&VidpfServerId::S1, &key_1, &public, &bad_input, &nonce)
.eval(VidpfServerId::S1, &key_1, &public, &bad_input, &nonce)
.unwrap();

assert_eq!(
Expand Down Expand Up @@ -929,18 +950,18 @@ mod tests {
weight: &TestWeight,
nonce: &[u8; TEST_NONCE_SIZE],
) {
let mut state_0 = VidpfEvalState::init_from_key(&VidpfServerId::S0, key_0);
let mut state_1 = VidpfEvalState::init_from_key(&VidpfServerId::S1, key_1);
let mut state_0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0);
let mut state_1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1);

let n = input.len();
for level in 0..n {
let share_0;
let share_1;
(state_0, share_0) = vidpf
.eval_next(&VidpfServerId::S0, public, input, level, &state_0, nonce)
.eval_next(VidpfServerId::S0, public, input, level, &state_0, nonce)
.unwrap();
(state_1, share_1) = vidpf
.eval_next(&VidpfServerId::S1, public, input, level, &state_1, nonce)
.eval_next(VidpfServerId::S1, public, input, level, &state_1, nonce)
.unwrap();

assert_eq!(
Expand All @@ -964,10 +985,12 @@ mod tests {
let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight);

equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce);
test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce);
}

fn equivalence_of_eval_with_caching(
/// Ensures that VIDPF outputs match regardless of whether the path to
/// each node is recomputed or cached during evaluation.
fn test_equivalence_of_eval_with_caching(
vidpf: &Vidpf<TestWeight, TEST_NONCE_SIZE>,
[key_0, key_1]: &[VidpfKey; 2],
public: &VidpfPublicShare<TestWeight>,
Expand All @@ -977,12 +1000,12 @@ mod tests {
let mut cache_tree_0 = BinaryTree::<VidpfEvalCache<TestWeight>>::default();
let mut cache_tree_1 = BinaryTree::<VidpfEvalCache<TestWeight>>::default();
let cache_0 = VidpfEvalCache::<TestWeight>::init_from_key(
&VidpfServerId::S0,
VidpfServerId::S0,
key_0,
&vidpf.weight_parameter,
);
let cache_1 = VidpfEvalCache::<TestWeight>::init_from_key(
&VidpfServerId::S1,
VidpfServerId::S1,
key_1,
&vidpf.weight_parameter,
);
Expand All @@ -997,7 +1020,7 @@ mod tests {
for level in 0..n {
let val_share_0 = vidpf
.eval(
&VidpfServerId::S0,
VidpfServerId::S0,
key_0,
public,
&input.prefix(level),
Expand All @@ -1006,7 +1029,7 @@ mod tests {
.unwrap();
let val_share_1 = vidpf
.eval(
&VidpfServerId::S1,
VidpfServerId::S1,
key_1,
public,
&input.prefix(level),
Expand All @@ -1015,7 +1038,7 @@ mod tests {
.unwrap();
let val_share_0_cached = vidpf
.eval_with_cache(
&VidpfServerId::S0,
VidpfServerId::S0,
key_0,
public,
&input.prefix(level),
Expand All @@ -1025,7 +1048,7 @@ mod tests {
.unwrap();
let val_share_1_cached = vidpf
.eval_with_cache(
&VidpfServerId::S1,
VidpfServerId::S1,
key_1,
public,
&input.prefix(level),
Expand Down

0 comments on commit 5a46f32

Please sign in to comment.