Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Prio3MultihotCountVec #1123

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ pub trait Type: Sized + Eq + Clone + Debug {
measurement: &Self::Measurement,
) -> Result<Vec<Self::Field>, FlpError>;

/// Decode an aggregate result.
/// Decode an aggregate result. This is NOT the inverse of `encode_measurement`. Rather, the
/// input is an aggregation of truncated measurements.
rozbb marked this conversation as resolved.
Show resolved Hide resolved
fn decode_result(
&self,
data: &[Self::Field],
Expand Down
350 changes: 346 additions & 4 deletions src/flp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::{FlpError, Gadget, Type};
use crate::polynomial::poly_range_check;
use crate::vdaf::prio3::ilog2;
use std::convert::TryInto;
use std::fmt::{self, Debug};
use std::marker::PhantomData;
Expand Down Expand Up @@ -471,6 +472,232 @@ where
}
}

/// The multihot counter data type. Each measurement is a set of integers in `[0, length)`, of size
/// at most `max_weight`, and the aggregate is a histogram counting the number of occurrences of
/// each integer across all measurements.
rozbb marked this conversation as resolved.
Show resolved Hide resolved
#[derive(PartialEq, Eq)]
pub struct MultihotCountVec<F, S> {
    /// Number of buckets; a measurement must contain exactly this many booleans.
    length: usize,
    /// Largest number of `true` entries a measurement may contain.
    max_weight: usize,
    /// Chunk length handed to the parallel-sum gadget `S` and to the range check.
    chunk_length: usize,
    /// Number of gadget calls: the encoded measurement length (buckets plus weight bits)
    /// divided by `chunk_length`, rounded up.
    gadget_calls: usize,
    /// Ties the type to the field `F` and gadget `S` without storing a value of either.
    phantom: PhantomData<(F, S)>,
}

impl<F: FftFriendlyFieldElement, S> Debug for MultihotCountVec<F, S> {
    /// Prints the user-supplied parameters only; the derived `gadget_calls` value and the
    /// phantom marker are left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            length,
            max_weight,
            chunk_length,
            ..
        } = self;

        let mut builder = f.debug_struct("MultihotCountVec");
        builder.field("length", length);
        builder.field("max_weight", max_weight);
        builder.field("chunk_length", chunk_length);
        builder.finish()
    }
}

impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> MultihotCountVec<F, S> {
    /// Return a new [`MultihotCountVec`] type with the given number of buckets.
    ///
    /// * `num_buckets` is the length of every measurement vector.
    /// * `max_weight` is the largest number of `true` entries a measurement may contain.
    /// * `chunk_length` is the chunk length used by the parallel-sum gadget.
    ///
    /// # Errors
    ///
    /// Returns [`FlpError::Encode`] if `num_buckets` is `u32::MAX` or larger, and
    /// [`FlpError::InvalidParameter`] if `num_buckets`, `chunk_length` or `max_weight` is zero.
    pub fn new(
        num_buckets: usize,
        max_weight: usize,
        chunk_length: usize,
    ) -> Result<Self, FlpError> {
        if num_buckets >= u32::MAX as usize {
            return Err(FlpError::Encode(
                "invalid num_buckets: exceeds maximum permitted".to_string(),
            ));
        }
        if num_buckets == 0 {
            return Err(FlpError::InvalidParameter(
                "num_buckets cannot be zero".to_string(),
            ));
        }
        if chunk_length == 0 {
            return Err(FlpError::InvalidParameter(
                "chunk_length cannot be zero".to_string(),
            ));
        }
        if max_weight == 0 {
            return Err(FlpError::InvalidParameter(
                "max_weight cannot be zero".to_string(),
            ));
        }

        // The bitlength of a measurement is the number of buckets plus the bitlength of the max
        // weight
        let bits_for_weight = ilog2(max_weight) as usize + 1;
        let meas_length = num_buckets + bits_for_weight;

        // Gadget calls is ⌈meas_length / chunk_length⌉. This is computed as quotient plus
        // "has a remainder" rather than `(meas_length + chunk_length - 1) / chunk_length`,
        // because that intermediate sum overflows when `chunk_length` is close to `usize::MAX`.
        let gadget_calls =
            meas_length / chunk_length + usize::from(meas_length % chunk_length != 0);

        Ok(Self {
            length: num_buckets,
            max_weight,
            chunk_length,
            gadget_calls,
            phantom: PhantomData,
        })
    }
}

// `Clone` is written by hand: `#[derive(Clone)]` would add `F: Clone` and `S: Clone` bounds,
// which do not hold in general and are not needed since neither type is actually stored.
impl<F, S> Clone for MultihotCountVec<F, S> {
    fn clone(&self) -> Self {
        // Every field is `Copy`, so destructure once and copy each value out.
        let Self {
            length,
            max_weight,
            chunk_length,
            gadget_calls,
            phantom,
        } = self;

        Self {
            length: *length,
            max_weight: *max_weight,
            chunk_length: *chunk_length,
            gadget_calls: *gadget_calls,
            phantom: *phantom,
        }
    }
}

impl<F, S> Type for MultihotCountVec<F, S>
where
    F: FftFriendlyFieldElement,
    S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
    type Measurement = Vec<bool>;
    type AggregateResult = Vec<F::Integer>;
    type Field = F;

    // Encodes a measurement as `length` field elements (one 0/1 value per bucket) followed by
    // the bit-encoding of the measurement's weight plus an offset.
    //
    // Fails with `FlpError::Encode` if the measurement does not have exactly `length` entries or
    // has more than `max_weight` entries set.
    fn encode_measurement(&self, measurement: &Vec<bool>) -> Result<Vec<F>, FlpError> {
        if measurement.len() != self.length {
            return Err(FlpError::Encode(format!(
                "unexpected measurement length: got {}; want {}",
                measurement.len(),
                self.length
            )));
        }

        let weight_reported = measurement.iter().filter(|bit| **bit).count();
        if weight_reported > self.max_weight {
            return Err(FlpError::Encode(format!(
                "unexpected measurement weight: got {}; want ≤{}",
                weight_reported, self.max_weight
            )));
        }

        let mut encoded: Vec<F> = Vec::with_capacity(self.input_len());

        // Convert bool vector to field elems
        encoded.extend(
            measurement
                .iter()
                .map(|bit| if *bit { F::one() } else { F::zero() }),
        );

        // Encode the measurement weight in binary (actually, the weight plus some offset) and
        // append it to the bucket values
        let offset_weight_reported =
            F::valid_integer_try_from(self.weight_offset() + weight_reported)?;
        encoded.extend(F::encode_as_bitvector(
            offset_weight_reported,
            self.bits_for_weight(),
        )?);

        Ok(encoded)
    }

    fn decode_result(
        &self,
        data: &[Self::Field],
        _num_measurements: usize,
    ) -> Result<Self::AggregateResult, FlpError> {
        // The aggregate is the same as the decoded result. Just convert to integers
        Ok(data.iter().map(|f| F::Integer::from(*f)).collect())
    }

    fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
        vec![Box::new(S::new(
            Mul::new(self.gadget_calls),
            self.chunk_length,
        ))]
    }

    // Evaluates the validity circuit on (a share of) an encoded measurement. The result is a
    // random linear combination of the range check and the weight check.
    fn valid(
        &self,
        g: &mut Vec<Box<dyn Gadget<F>>>,
        input: &[F],
        joint_rand: &[F],
        num_shares: usize,
    ) -> Result<F, FlpError> {
        self.valid_call_check(input, joint_rand)?;

        // Check that each element of `input` is a 0 or 1.
        let range_check = parallel_sum_range_checks(
            &mut g[0],
            input,
            joint_rand[0],
            self.chunk_length,
            num_shares,
        )?;

        // Check that the elements of `input` sum to at most `max_weight`.
        let (count_vec, weight_bits) = input.split_at(self.length);
        let weight = count_vec.iter().fold(F::zero(), |a, b| a + *b);
        let offset_weight_reported = F::decode_bitvector(weight_bits)?;

        // From spec: weight_check = self.offset*shares_inv + weight - weight_reported
        let weight_check = {
            let offset = F::from(F::valid_integer_try_from(self.weight_offset())?);
            let shares_inv = F::from(F::valid_integer_try_from(num_shares)?).inv();
            offset * shares_inv + weight - offset_weight_reported
        };

        // Take a random linear combination of both checks.
        let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * weight_check;
        Ok(out)
    }

    // Truncates the measurement, removing extra data that was necessary for validity (here, the
    // encoded weight), but not important for aggregation
    fn truncate(&self, mut input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError> {
        // Reject inputs of the wrong length with an error instead of panicking on the slice.
        // NOTE(review): assumes the `Type` trait provides `truncate_call_check` next to
        // `valid_call_check` -- confirm against the trait definition.
        self.truncate_call_check(&input)?;

        // Cut off the encoded weight in place; no copy of the bucket values is needed
        input.truncate(self.length);
        Ok(input)
    }

    // The length in field elements of the encoded input returned by [`Self::encode_measurement`].
    fn input_len(&self) -> usize {
        self.length + self.bits_for_weight()
    }

    fn proof_len(&self) -> usize {
        (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
    }

    fn verifier_len(&self) -> usize {
        2 + self.chunk_length * 2
    }

    // The length of the truncated output (i.e., the output of [`Type::truncate`]).
    fn output_len(&self) -> usize {
        self.length
    }

    // The number of random values needed in the validity checks
    fn joint_rand_len(&self) -> usize {
        2
    }

    fn prove_rand_len(&self) -> usize {
        self.chunk_length * 2
    }

    fn query_rand_len(&self) -> usize {
        // TODO: this will need to be increased once draft-10 is implemented and more randomness
        // is necessary due to random linear combination computations
        1
    }
}

// Private helpers shared by the `Type` implementation above.
impl<F, S> MultihotCountVec<F, S> {
    // Number of bits used to encode the (offset) weight: `ilog2(max_weight) + 1`.
    fn bits_for_weight(&self) -> usize {
        ilog2(self.max_weight) as usize + 1
    }

    // Offset added to the weight before it is bit-encoded: `2^bits_for_weight - 1 - max_weight`.
    // With this offset, `max_weight` maps to the all-ones bit string, so a larger weight does
    // not fit into `bits_for_weight` bits.
    fn weight_offset(&self) -> usize {
        (1usize << self.bits_for_weight()) - 1 - self.max_weight
    }
}

/// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19],
/// Corollary 4.9] to reduce the proof size to roughly the square root of the input size.
///
Expand Down Expand Up @@ -685,13 +912,13 @@ pub(crate) fn call_gadget_on_vec_entries<F: FftFriendlyFieldElement>(
input: &[F],
rnd: F,
) -> Result<F, FlpError> {
let mut range_check = F::zero();
let mut comb = F::zero();
let mut r = rnd;
for chunk in input.chunks(1) {
range_check += r * g.call(chunk)?;
comb += r * g.call(chunk)?;
r *= rnd;
}
Ok(range_check)
Ok(comb)
}

/// Given a vector `data` of field elements which should contain exactly one entry, return the
Expand Down Expand Up @@ -776,7 +1003,9 @@ pub(crate) fn parallel_sum_range_checks<F: FftFriendlyFieldElement>(
#[cfg(test)]
mod tests {
use super::*;
use crate::field::{random_vector, Field64 as TestField, FieldElement};
use crate::field::{
random_vector, Field64 as TestField, FieldElement, FieldElementWithInteger,
};
use crate::flp::gadgets::ParallelSum;
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
Expand Down Expand Up @@ -957,6 +1186,119 @@ mod tests {
);
}

fn test_multihot<F, S>(constructor: F)
where
F: Fn(usize, usize, usize) -> Result<MultihotCountVec<TestField, S>, FlpError>,
S: ParallelSumGadget<TestField, Mul<TestField>> + Eq + 'static,
{
const NUM_SHARES: usize = 3;

// Chunk size for our range check gadget
let chunk_size = 2;

// Our test is on multihot vecs of length 3, with max weight 2
let num_buckets = 3;
let max_weight = 2;

let multihot_instance = constructor(num_buckets, max_weight, chunk_size).unwrap();
let zero = TestField::zero();
let one = TestField::one();
let nine = TestField::from(9);

let encoded_weight_plus_offset = |weight| {
let bits_for_weight = ilog2(max_weight) as usize + 1;
let offset = (1 << bits_for_weight) - 1 - max_weight;
TestField::encode_as_bitvector(
<TestField as FieldElementWithInteger>::Integer::try_from(weight + offset).unwrap(),
bits_for_weight,
)
.unwrap()
.collect::<Vec<TestField>>()
};

assert_eq!(
multihot_instance
.encode_measurement(&vec![true, true, false])
.unwrap(),
[&[one, one, zero], &*encoded_weight_plus_offset(2)].concat(),
);
assert_eq!(
multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap(),
[&[zero, one, one], &*encoded_weight_plus_offset(2)].concat(),
);

// Round trip
assert_eq!(
multihot_instance
.decode_result(
&multihot_instance
.truncate(
multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap()
)
.unwrap(),
1
)
.unwrap(),
[0, 1, 1]
);

// Test valid inputs with weights 0, 1, and 2
FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![true, false, false])
.unwrap(),
&[one, zero, zero],
);

FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap(),
&[zero, one, one],
);

FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![false, false, false])
.unwrap(),
&[zero, zero, zero],
);

// Test invalid inputs.

// Not binary
FlpTest::expect_invalid::<NUM_SHARES>(
&multihot_instance,
&[&[zero, zero, nine], &*encoded_weight_plus_offset(1)].concat(),
);
// Wrong weight
FlpTest::expect_invalid::<NUM_SHARES>(
&multihot_instance,
&[&[zero, zero, one], &*encoded_weight_plus_offset(2)].concat(),
);
// Weight too high. This actually panics because weight + offset cannot fit into a bitvector
// of the correct length. In other words, being out-of-range requires the prover to lie
// about their weight, which is tested above
/*
FlpTest::expect_invalid::<NUM_SHARES>(
&multihot_instance,
&[&[one, one, one], &*encoded_weight_plus_offset(3)].concat(),
);
*/
rozbb marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
fn test_multihot_serial() {
test_multihot(MultihotCountVec::<TestField, ParallelSum<TestField, Mul<TestField>>>::new);
}

fn test_sum_vec<F, S>(f: F)
where
F: Fn(usize, usize, usize) -> Result<SumVec<TestField, S>, FlpError>,
Expand Down
Loading