Skip to content

Commit

Permalink
Implement Prio3MultiHotHistogram that checks bounded number of 1s
Browse files Browse the repository at this point in the history
Add a reference implementation of the idea in issue #287, which
needs a new FLP 'MultiHotHistogram' and a new Prio3 type
'Prio3MultiHotHistogram' that checks a Client's measurement has a
bounded number of 1s.
  • Loading branch information
junyechen1996 committed Oct 12, 2023
1 parent 2ed6a6b commit 42443c5
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 1 deletion.
159 changes: 158 additions & 1 deletion poc/flp_generic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A generic FLP based on {{BBCGGI19}}, Theorem 4.3."""

import copy
import math

import field
from common import ERR_ABORT, ERR_INPUT, Unsigned, Vec, next_power_of_2
Expand Down Expand Up @@ -574,7 +575,7 @@ class Histogram(Valid):

def __init__(self, length, chunk_length):
"""
Instantiate an instace of the `Histogram` circuit with the given
Instantiate an instance of the `Histogram` circuit with the given
length and chunk_length.
"""

Expand Down Expand Up @@ -640,6 +641,140 @@ def test_vec_set_type_param(self, test_vec):
return ['length', 'chunk_length']


class MultiHotHistogram(Valid):
    r"""
    A validity circuit that checks each Client's measurement is a bit vector,
    and has at most `max_num_bits` 1s.

    In order to check the Client measurement `x` has at most `max_num_bits` 1s,
    we ask the Client to send an additional `num_bits_for_count` bits that
    encode the number of 1s in the measurement, with an offset. Specifically:

    - Let `num_bits_for_count = int(log2(max_num_bits)) + 1`, i.e. the bit
      length of `max_num_bits`.
    - Let `offset = 2**num_bits_for_count - 1 - max_num_bits`.
    - Client will encode `count = offset + \sum_i x_i` in `num_bits_for_count`
      bits.
    - We can naturally bound `count` as the following:
      `0 <= count <= 2**num_bits_for_count - 1`, and therefore:
      `-offset <= \sum_i x_i <= max_num_bits`.
    - Since we also verify each `x_i` is a bit, we can lower bound the
      summation by 0. Therefore, we will be able to verify
      `0 <= \sum_i x_i <= max_num_bits`.
    """
    # Operational parameters
    length = None  # Set by the constructor
    max_num_bits = None  # Set by the constructor
    chunk_length = None  # Set by the constructor

    # Associated types
    # A measurement is a vector of `length` buckets (see `encode()`), not a
    # single unsigned integer.
    Measurement = Vec[Unsigned]
    AggResult = Vec[Unsigned]
    Field = field.Field128

    # Associated parameters
    GADGETS = None  # Set by the constructor
    GADGET_CALLS = None  # Set by the constructor
    MEAS_LEN = None  # Set by the constructor
    JOINT_RAND_LEN = 2
    OUTPUT_LEN = None  # Set by the constructor

    def __init__(self, length, max_num_bits, chunk_length):
        """
        Instantiate an instance of the `MultiHotHistogram` circuit with the
        given length, max_num_bits, and chunk_length.

        Raises `ValueError` if `length` or `chunk_length` is not positive, if
        `max_num_bits` is not in the range `[1, length)`, or if
        `offset + length` could wrap around the field modulus.
        """
        if length <= 0:
            raise ValueError('invalid length')
        if max_num_bits <= 0 or max_num_bits >= length:
            raise ValueError('invalid max_num_bits')
        if chunk_length <= 0:
            raise ValueError('invalid chunk_length')

        # Compute the number of bits to represent `max_num_bits`. For a
        # positive integer this equals `floor(log2(max_num_bits)) + 1`, but
        # `int.bit_length()` is computed exactly, whereas a floating-point
        # `log2` can round up for large values just below a power of two.
        self.num_bits_for_count = max_num_bits.bit_length()
        self.offset = self.Field(
            (1 << self.num_bits_for_count) - 1 - max_num_bits
        )
        # Sanity check `offset + length` doesn't overflow field size,
        # because in validity circuit, we will compute `offset + \sum_i x_i`.
        if self.Field.MODULUS - self.offset.as_unsigned() <= length:
            raise ValueError('length and max_num_bits are too large '
                             'for the current field size')

        self.length = length
        self.max_num_bits = max_num_bits
        self.chunk_length = chunk_length
        self.GADGETS = [ParallelSum(Mul(), chunk_length)]
        # The number of bit entries are `length + num_bits_for_count`,
        # so the number of gadget calls is equal to
        # `ceil((length + num_bits_for_count) / chunk_length)`.
        self.GADGET_CALLS = [
            (length + self.num_bits_for_count + chunk_length - 1) //
            chunk_length
        ]
        self.MEAS_LEN = self.length + self.num_bits_for_count
        self.OUTPUT_LEN = self.length

    def eval(self, meas, joint_rand, num_shares):
        """
        Evaluate the validity circuit on the encoded measurement `meas`
        (`length` bucket entries followed by `num_bits_for_count` count bits).

        The result combines a range check over every entry of `meas` with a
        check that the bucket sum plus `offset` matches the count encoded in
        the trailing bits, using `joint_rand[1]` and its square as weights.
        """
        self.check_valid_eval(meas, joint_rand)

        # Check that each bucket is one or zero.
        range_check = self.Field(0)
        r = joint_rand[0]
        r_power = r
        # Constants are scaled by `1 / num_shares`.
        # NOTE(review): presumably so that they add up to the intended
        # constant once all `num_shares` shares are combined -- confirm.
        shares_inv = self.Field(num_shares).inv()
        for i in range(self.GADGET_CALLS[0]):
            # Each gadget call takes `chunk_length` pairs, laid out as
            # consecutive entries of `inputs`.
            inputs = [None] * (2 * self.chunk_length)
            for j in range(self.chunk_length):
                index = i * self.chunk_length + j
                if index < len(meas):
                    meas_elem = meas[index]
                else:
                    # Pad the final chunk with zeros.
                    meas_elem = self.Field(0)

                inputs[j * 2] = r_power * meas_elem
                inputs[j * 2 + 1] = meas_elem - shares_inv

                # Every position, including padding, gets its own power of
                # `r`.
                r_power *= r

            range_check += self.GADGETS[0].eval(self.Field, inputs)

        # Check that `offset + \sum_{i=0}^{length-1} meas[i]` is equal to the
        # value claimed by the Client.
        count_check = self.offset * shares_inv
        for i in range(self.length):
            count_check += meas[i]
        count_check -= self.Field.decode_from_bit_vector(meas[self.length:])

        out = (joint_rand[1] * range_check +
               joint_rand[1] ** 2 * count_check)
        return out

    def encode(self, measurement: Vec[Unsigned]):
        """
        Encode `measurement` (a vector of `length` buckets) as field elements,
        followed by `num_bits_for_count` bits encoding `offset + sum of the
        buckets`.

        Raises `ValueError` if `measurement` does not have `length` entries.
        The bucket values themselves are not validated here.
        """
        if len(measurement) != self.length:
            raise ValueError('invalid Client measurement length')

        encoded = list(map(self.Field, measurement))
        # Encode the result of `offset + \sum_i measurement_i` into
        # `num_bits_for_count` bits. The sum is taken in the field.
        count = self.offset + sum(encoded, self.Field(0))
        encoded += self.Field.encode_into_bit_vector(
            count.as_unsigned(), self.num_bits_for_count
        )
        return encoded

    def truncate(self, meas):
        """Drop the trailing count bits, keeping only the `length` buckets."""
        return meas[:self.length]

    def decode(self, output, _num_measurements):
        """Convert each aggregated bucket to an unsigned integer."""
        return [bucket_count.as_unsigned() for bucket_count in output]

    def test_vec_set_type_param(self, test_vec):
        """
        Record this circuit's parameters in `test_vec` and return the list of
        keys that were set.
        """
        test_vec['length'] = int(self.length)
        test_vec['max_num_bits'] = int(self.max_num_bits)
        test_vec['chunk_length'] = int(self.chunk_length)
        return ['length', 'max_num_bits', 'chunk_length']


class SumVec(Valid):
# Operational parameters
length = None # Set by constructor
Expand Down Expand Up @@ -863,6 +998,28 @@ def test():
(flp.Field.rand_vec(4), False),
])

# MultiHotHistogram with length = 4, max_num_bits = 2, chunk_length = 2.
flp = FlpGeneric(MultiHotHistogram(4, 2, 2))
# Successful cases:
cases = [
(flp.encode([0, 0, 0, 0]), True),
(flp.encode([0, 1, 0, 0]), True),
(flp.encode([0, 1, 1, 0]), True),
]
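# Failure cases: the claimed count bits encode `i` rather than `offset + i`
# (and for i > max_num_bits there are also too many 1s), so the count check
# should fail.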
cases += [
(
[flp.Field(1)] * i +
[flp.Field(0)] * (flp.Valid.length - i) +
flp.Field.encode_into_bit_vector(i, 2),
False
)
for i in range(flp.Valid.max_num_bits, flp.Valid.length)
]
# Failure case: pass count check but fail bit check.
cases += [(flp.encode([flp.Field.MODULUS - 1, 1, 0, 0]), False)]
test_flp_generic(flp, cases)

# SumVec with length 2, bits 4, chunk len 1.
flp = FlpGeneric(SumVec(2, 4, 1))
# Roundtrip test with no proof generated.
Expand Down
49 changes: 49 additions & 0 deletions poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,34 @@ class Prio3HistogramWithLength(Prio3Histogram):
return Prio3HistogramWithLength


class Prio3MultiHotHistogram(Prio3):
    """
    A Prio3 instantiation built on the `MultiHotHistogram` validity circuit
    from `flp_generic`. Use `with_params()` to obtain a usable subclass.
    """

    # Generic types required by `Prio3`
    Xof = xof.XofShake128

    # Associated parameters.
    VERIFY_KEY_SIZE = Xof.SEED_SIZE
    ID = 0x00000FFF  # private codepoint for Prio3 types

    # Operational parameters.
    test_vec_name = 'Prio3MultiHotHistogram'

    @classmethod
    def with_params(Prio3MultiHotHistogram,
                    length: Unsigned,
                    max_num_bits: Unsigned,
                    chunk_length: Unsigned):
        """
        Return a subclass whose `Flp` is a `FlpGeneric` wrapping a
        `MultiHotHistogram` circuit constructed from `length`,
        `max_num_bits` and `chunk_length`.
        """
        # Build the validity circuit first, then bind it into the subclass.
        circuit = flp_generic.MultiHotHistogram(
            length,
            max_num_bits,
            chunk_length,
        )

        class Prio3MultiHotHistogramWithParams(Prio3MultiHotHistogram):
            Flp = flp_generic.FlpGeneric(circuit)

        return Prio3MultiHotHistogramWithParams


##
# TESTS
#


class TestPrio3Average(Prio3):
"""
A Prio3 instantiation to test use of num_measurements in the Valid
Expand Down Expand Up @@ -594,6 +618,31 @@ class TestPrio3AverageWithBits(TestPrio3Average):
test_vec_instance=1,
)

# Prio3MultiHotHistogram with length = 4, max_num_bits = 2,
# chunk_length = 2.
cls = Prio3MultiHotHistogram \
.with_params(4, 2, 2) \
.with_shares(num_shares)
assert cls.ID == 0x00000FFF
test_vdaf(cls, None, [[0, 0, 0, 0]], [0, 0, 0, 0])
test_vdaf(cls, None, [[0, 1, 0, 0]], [0, 1, 0, 0])
test_vdaf(cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0])
test_vdaf(cls, None, [[0, 1, 1, 0], [0, 1, 0, 1]], [0, 2, 1, 1])
test_vdaf(
cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0], print_test_vec=TEST_VECTOR
)
# Prio3MultiHotHistogram with length = 11, max_num_bits = 5,
# chunk_length = 3.
cls = Prio3MultiHotHistogram.with_params(11, 5, 3).with_shares(3)
test_vdaf(
cls,
None,
[[1] * 5 + [0] * 6],
[1] * 5 + [0] * 6,
print_test_vec=TEST_VECTOR,
test_vec_instance=1,
)

cls = TestPrio3Average.with_bits(3).with_shares(num_shares)
test_vdaf(cls, None, [1, 5, 1, 1, 4, 1, 3, 2], 2)

Expand Down

0 comments on commit 42443c5

Please sign in to comment.