Skip to content

Commit

Permalink
Add poc changes for
Browse files Browse the repository at this point in the history
  • Loading branch information
albertpl committed Oct 3, 2023
1 parent 402438c commit 8b74ffc
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 24 deletions.
22 changes: 22 additions & 0 deletions poc/flp_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,12 @@ class SumVec(Valid):
JOINT_RAND_LEN = 1
OUTPUT_LEN = None # Set by constructor

@classmethod
def with_field(SumVec, field: field.FftField):
class SumVecWithField(SumVec):
Field = field
return SumVecWithField

def __init__(self, length, bits, chunk_length):
"""
Instantiate the `SumVec` circuit for measurements with `length`
Expand Down Expand Up @@ -829,6 +835,20 @@ def decode(self, output, num_measurements):
return total // num_measurements


def _test_sumvec_with_field(f: field.FftField):
cls = SumVec.with_field(f)
assert cls.Field == f
flp = FlpGeneric(cls(2, 4, 1))
# Roundtrip test with no proof generated.
for meas in [[1, 2], [3, 4], [5, 6], [7, 8]]:
assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1)


def test_sumvec_with_field():
for f in [field.Field64, field.Field96, field.Field128]:
_test_sumvec_with_field(f)


def test():
flp = FlpGeneric(Count())
test_flp_generic(flp, [
Expand Down Expand Up @@ -874,6 +894,8 @@ def test():
(flp.encode(0), True),
])

test_sumvec_with_field()


if __name__ == '__main__':
test()
158 changes: 134 additions & 24 deletions poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import xof
from common import (ERR_INPUT, ERR_VERIFY, TEST_VECTOR, Unsigned, byte, concat,
front, vec_add, vec_sub, zeros)
from field import FftField, Field64, Field96, Field128
from vdaf import Vdaf, test_vdaf

USAGE_MEAS_SHARE = 1
Expand All @@ -31,6 +32,7 @@ class Prio3(Vdaf):
RAND_SIZE = None # Computed from `Xof.SEED_SIZE` and `SHARES`
ROUNDS = 1
SHARES = None # A number between `[2, 256)` set later
PROOFS = 1 # Number of independent proofs

# Types required by `Vdaf`
Measurement = Flp.Measurement
Expand Down Expand Up @@ -73,7 +75,7 @@ def is_valid(agg_param, previous_agg_params):
def prep_init(Prio3, verify_key, agg_id, _agg_param,
nonce, public_share, input_share):
k_joint_rand_parts = public_share
(meas_share, proof_share, k_blind) = \
(meas_share, proof_shares, k_blind) = \
Prio3.expand_input_share(agg_id, input_share)
out_share = Prio3.Flp.truncate(meas_share)

Expand All @@ -86,15 +88,24 @@ def prep_init(Prio3, verify_key, agg_id, _agg_param,
k_joint_rand_parts[agg_id] = k_joint_rand_part
k_corrected_joint_rand = Prio3.joint_rand_seed(
k_joint_rand_parts)
joint_rand = Prio3.joint_rand(k_corrected_joint_rand)
joint_rands = Prio3.joint_rands(k_corrected_joint_rand)

# Query the measurement and proof share.
query_rand = Prio3.query_rand(verify_key, nonce)
verifier_share = Prio3.Flp.query(meas_share,
proof_share,
query_rand,
joint_rand,
Prio3.SHARES)
query_rands = Prio3.query_rand(verify_key, nonce)
verifier_share = []
for _ in range(Prio3.PROOFS):
proof_share, proof_shares = front(
Prio3.Flp.PROOF_LEN, proof_shares)
query_rand, query_rands = front(
Prio3.Flp.QUERY_RAND_LEN, query_rands)
if Prio3.Flp.JOINT_RAND_LEN > 0:
joint_rand, joint_rands = front(
Prio3.Flp.JOINT_RAND_LEN, joint_rands)
verifier_share += Prio3.Flp.query(meas_share,
proof_share,
query_rand,
joint_rand,
Prio3.SHARES)

prep_state = (out_share, k_corrected_joint_rand)
prep_share = (verifier_share, k_joint_rand_part)
Expand All @@ -115,16 +126,19 @@ def prep_next(Prio3, prep, prep_msg):
@classmethod
def prep_shares_to_prep(Prio3, _agg_param, prep_shares):
# Unshard the verifier shares into the verifier message.
verifier = Prio3.Flp.Field.zeros(Prio3.Flp.VERIFIER_LEN)
verifiers = Prio3.Flp.Field.zeros(
Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS)
k_joint_rand_parts = []
for (verifier_share, k_joint_rand_part) in prep_shares:
verifier = vec_add(verifier, verifier_share)
verifiers = vec_add(verifiers, verifier_share)
if Prio3.Flp.JOINT_RAND_LEN > 0:
k_joint_rand_parts.append(k_joint_rand_part)

# Verify that the proof is well-formed and the input is valid.
if not Prio3.Flp.decide(verifier):
raise ERR_VERIFY # proof verifier check failed
# Verify that all the proofs are well-formed and the input is valid.
for _ in range(Prio3.PROOFS):
verifier, verifiers = front(Prio3.Flp.VERIFIER_LEN, verifiers)
if not Prio3.Flp.decide(verifier):
raise ERR_VERIFY # proof verifier check failed

# Combine the joint randomness parts computed by the
# Aggregators into the true joint randomness seed. This is
Expand Down Expand Up @@ -173,8 +187,12 @@ def shard_without_joint_rand(Prio3, meas, seeds):
)

# Generate the proof and shard it into proof shares.
prove_rand = Prio3.prove_rand(k_prove)
leader_proof_share = Prio3.Flp.prove(meas, prove_rand, [])
prove_rands = Prio3.prove_rands(k_prove)
leader_proof_share = []
for _ in range(Prio3.PROOFS):
prove_rand, prove_rands = front(
Prio3.Flp.PROVE_RAND_LEN, prove_rands)
leader_proof_share += Prio3.Flp.prove(meas, prove_rand, [])
for j in range(Prio3.SHARES-1):
leader_proof_share = vec_sub(
leader_proof_share,
Expand Down Expand Up @@ -230,10 +248,16 @@ def shard_with_joint_rand(Prio3, meas, nonce, seeds):
0, k_leader_blind, leader_meas_share, nonce))

# Generate the proof and shard it into proof shares.
prove_rand = Prio3.prove_rand(k_prove)
joint_rand = Prio3.joint_rand(
prove_rands = Prio3.prove_rands(k_prove)
joint_rands = Prio3.joint_rands(
Prio3.joint_rand_seed(k_joint_rand_parts))
leader_proof_share = Prio3.Flp.prove(meas, prove_rand, joint_rand)
leader_proof_share = []
for _ in range(Prio3.PROOFS):
prove_rand, prove_rands = front(
Prio3.Flp.PROVE_RAND_LEN, prove_rands)
joint_rand, joint_rands = front(
Prio3.Flp.JOINT_RAND_LEN, joint_rands)
leader_proof_share += Prio3.Flp.prove(meas, prove_rand, joint_rand)
for j in range(Prio3.SHARES-1):
leader_proof_share = vec_sub(
leader_proof_share,
Expand Down Expand Up @@ -274,7 +298,7 @@ def helper_proof_share(Prio3, agg_id, k_share):
k_share,
Prio3.domain_separation_tag(USAGE_PROOF_SHARE),
byte(agg_id),
Prio3.Flp.PROOF_LEN,
Prio3.Flp.PROOF_LEN * Prio3.PROOFS,
)

@classmethod
Expand All @@ -286,13 +310,13 @@ def expand_input_share(Prio3, agg_id, input_share):
return (meas_share, proof_share, k_blind)

@classmethod
def prove_rand(Prio3, k_prove):
def prove_rands(Prio3, k_prove):
return Prio3.Xof.expand_into_vec(
Prio3.Flp.Field,
k_prove,
Prio3.domain_separation_tag(USAGE_PROVE_RANDOMNESS),
b'',
Prio3.Flp.PROVE_RAND_LEN,
Prio3.Flp.PROVE_RAND_LEN * Prio3.PROOFS,
)

@classmethod
Expand All @@ -302,7 +326,7 @@ def query_rand(Prio3, verify_key, nonce):
verify_key,
Prio3.domain_separation_tag(USAGE_QUERY_RANDOMNESS),
nonce,
Prio3.Flp.QUERY_RAND_LEN,
Prio3.Flp.QUERY_RAND_LEN * Prio3.PROOFS,
)

@classmethod
Expand All @@ -323,14 +347,14 @@ def joint_rand_seed(Prio3, k_joint_rand_parts):
)

@classmethod
def joint_rand(Prio3, k_joint_rand_seed):
def joint_rands(Prio3, k_joint_rand_seed):
"""Derive the joint randomness from its seed."""
return Prio3.Xof.expand_into_vec(
Prio3.Flp.Field,
k_joint_rand_seed,
Prio3.domain_separation_tag(USAGE_JOINT_RANDOMNESS),
b'',
Prio3.Flp.JOINT_RAND_LEN,
Prio3.Flp.JOINT_RAND_LEN * Prio3.PROOFS,
)

@classmethod
Expand Down Expand Up @@ -370,6 +394,7 @@ def test_vec_encode_input_share(Prio3, input_share):
(meas_share, proof_share, k_blind) = input_share
encoded = bytes()
if type(meas_share) == list and type(proof_share) == list: # leader
assert len(proof_share) == Prio3.Flp.PROOF_LEN * Prio3.PROOFS
encoded += Prio3.Flp.Field.encode_vec(meas_share)
encoded += Prio3.Flp.Field.encode_vec(proof_share)
elif type(meas_share) == bytes and type(proof_share) == bytes: # helper
Expand All @@ -394,6 +419,7 @@ def test_vec_encode_agg_share(Prio3, agg_share):
def test_vec_encode_prep_share(Prio3, prep_share):
(verifier_share, k_joint_rand_part) = prep_share
encoded = bytes()
assert len(verifier_share) == Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS
encoded += Prio3.Flp.Field.encode_vec(verifier_share)
if k_joint_rand_part != None: # joint randomness used
encoded += k_joint_rand_part
Expand Down Expand Up @@ -483,6 +509,41 @@ class Prio3HistogramWithLength(Prio3Histogram):
return Prio3HistogramWithLength


class Prio3SumVecWithMultiproof(Prio3SumVec):
ID = 0xFFFFFFFF # TBD

# Operational parameters.
test_vec_name = 'Prio3SumVecWithMultiproof'

@staticmethod
def is_recommended(num_proofs: Unsigned, field: FftField) -> bool:
# To be confirmed
if field == Field64:
return num_proofs >= 2
elif field == Field96:
return num_proofs >= 2
elif field == Field128:
return num_proofs >= 1
return False

@classmethod
def with_params(cls,
length: Unsigned,
bits: Unsigned,
chunk_length: Unsigned,
num_proofs: Unsigned,
field: FftField):
if not cls.is_recommended(num_proofs, field):
raise ERR_INPUT

valid_cls = flp_generic.SumVec.with_field(field)

class Prio3SumVecWithMultiproofAndParams(cls):
PROOFS = num_proofs
Flp = flp_generic.FlpGeneric(valid_cls(length, bits, chunk_length))
return Prio3SumVecWithMultiproofAndParams


##
# TESTS
#
Expand All @@ -506,6 +567,53 @@ class TestPrio3AverageWithBits(TestPrio3Average):
return TestPrio3AverageWithBits


def _test_prio3sumvec(num_proofs: Unsigned, field: FftField):
cls = Prio3SumVecWithMultiproof.with_params(
10, 8, 9, num_proofs=num_proofs, field=field).with_shares(2)

assert cls.ID == 0xFFFFFFFF
assert cls.PROOFS == num_proofs

test_vdaf(
cls,
None,
[[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]],
[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]
)
test_vdaf(
cls,
None,
[
list(range(10)),
[1] * 10,
[255] * 10
],
list(range(256, 266)),
print_test_vec=TEST_VECTOR,
)
cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3)
test_vdaf(
cls,
None,
[
[10000, 32000, 9],
[19342, 19615, 3061],
[15986, 24671, 23910]
],
[45328, 76286, 26980],
print_test_vec=TEST_VECTOR,
test_vec_instance=1,
)


def test_prio3sumvec_with_multiproof():
for n in range(1, 5):
for f in [Field64, Field96, Field128]:
if not Prio3SumVecWithMultiproof.is_recommended(n, f):
continue
_test_prio3sumvec(num_proofs=n, field=f)


if __name__ == '__main__':
num_shares = 2 # Must be in range `[2, 256)`

Expand Down Expand Up @@ -601,3 +709,5 @@ class TestPrio3AverageWithBits(TestPrio3Average):
# otherwise.
assert cls.is_valid(None, set([]))
assert not cls.is_valid(None, set([None]))

test_prio3sumvec_with_multiproof()

0 comments on commit 8b74ffc

Please sign in to comment.