diff --git a/draft-irtf-cfrg-vdaf.md b/draft-irtf-cfrg-vdaf.md index fe6dd2bd..96c139bc 100644 --- a/draft-irtf-cfrg-vdaf.md +++ b/draft-irtf-cfrg-vdaf.md @@ -1047,6 +1047,22 @@ derive per-report randomness for verification of the computation. See ## Preparation {#sec-daf-prepare} +~~~ + Aggregator 0 Aggregator 1 Aggregator SHARES-1 + ============ ============ =================== + + input_share_0 input_share_1 input_share_[SHARES-1] + | | ... | + V V V + +-----------+ +-----------+ +-----------+ + | prep | | prep | | prep | + +-----------+ +-----------+ +-----------+ + | | ... | + V V V + out_share_0 out_share_1 out_share_[SHARES-1] +~~~ +{: #daf-prep-flow title="Illustration of preparation."} + Once an Aggregator has received the public share and its input share, the next step is to prepare the input share for aggregation. This is accomplished using the preparation algorithm: @@ -1202,31 +1218,19 @@ def run_daf( Pre-conditions: - - `type(agg_param) == daf.AggParam` - - `type(measurement) == daf.Measurement` for each - `measurement` in `measurements` - - `len(nonce) == daf.NONCE_SIZE` for each `nonce` in `nonces` - `len(nonces) == len(measurements)` + - `all(len(nonce) == daf.NONCE_SIZE for nonce in nonces)` """ - if any(len(nonce) != daf.NONCE_SIZE for nonce in nonces): - raise ValueError("incorrect nonce size") - if len(nonces) != len(measurements): - raise ValueError( - "measurements and nonces lists have different lengths" - ) - agg_shares: list[AggShare] agg_shares = [daf.agg_init(agg_param) for _ in range(daf.SHARES)] for (measurement, nonce) in zip(measurements, nonces): - # Each Client shards its measurement into input shares and - # distributes them among the Aggregators. + # Sharding rand = gen_rand(daf.RAND_SIZE) (public_share, input_shares) = \ daf.shard(ctx, measurement, nonce, rand) - # Each Aggregator computes its output share from its input - # share and aggregates it. + # Preparation, aggregation for j in range(daf.SHARES): out_share = daf.prep(ctx, j, agg_param, nonce, public_share, input_shares[j]) @@ -1234,7 +1238,7 @@ def run_daf( agg_shares[j], out_share) - # Collector unshards the aggregate result. + # Unsharding num_measurements = len(measurements) agg_result = daf.unshard(agg_param, agg_shares, num_measurements) @@ -1491,27 +1495,17 @@ def run_vdaf( - `len(nonces) == len(measurements)` - `all(len(nonce) == vdaf.NONCE_SIZE for nonce in nonces)` """ - - if len(verify_key) != vdaf.VERIFY_KEY_SIZE: - raise ValueError("incorrect verify_key size") - if any(len(nonce) != vdaf.NONCE_SIZE for nonce in nonces): - raise ValueError("incorrect nonce size") - if len(nonces) != len(measurements): - raise ValueError( - "measurements and nonces lists have different lengths" - ) - agg_shares = [vdaf.agg_init(agg_param) for _ in range(vdaf.SHARES)] for (nonce, measurement) in zip(nonces, measurements): assert len(nonce) == vdaf.NONCE_SIZE - # Each Client shards its measurement into input shares. + # Sharding rand = gen_rand(vdaf.RAND_SIZE) (public_share, input_shares) = \ vdaf.shard(ctx, measurement, nonce, rand) - # Each Aggregator initializes its preparation state. + # Initialize preparation prep_states = [] outbound_prep_shares = [] for j in range(vdaf.SHARES): @@ -1523,7 +1517,7 @@ def run_vdaf( prep_states.append(state) outbound_prep_shares.append(share) - # Aggregators complete preparation. + # Complete preparation for i in range(vdaf.ROUNDS - 1): prep_msg = vdaf.prep_shares_to_prep(ctx, agg_param, @@ -1540,7 +1534,7 @@ def run_vdaf( agg_param, outbound_prep_shares) - # Each Aggregator computes and aggregates its output share. + # Aggregation for j in range(vdaf.SHARES): out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg) assert not isinstance(out_share, tuple) @@ -1548,7 +1542,7 @@ def run_vdaf( agg_shares[j], out_share) - # Collector unshards the aggregate. + # Unsharding num_measurements = len(measurements) agg_result = vdaf.unshard(agg_param, agg_shares, num_measurements) @@ -1591,36 +1585,8 @@ State transitions are made when the state is acted upon by the Aggregator's local inputs and/or messages sent by its co-Aggregators. The initial state is `Start`. The terminal states are: `Rejected`, indicating that the report cannot be processed any further; and `Finished(out_share)`, indicating that the -Aggregator has recovered an output share `out_share`. - -~~~ python -class State: - pass - -class Start(State): - pass - -class Continued(State, Generic[PrepState]): - def __init__(self, prep_state: PrepState, prep_round: int): - self.prep_state = prep_state - self.prep_round = prep_round - - def __eq__(self, other: object) -> bool: - return isinstance(other, Continued) and \ - self.prep_state == other.prep_state and \ - self.prep_round == other.prep_round - -class Finished(State, Generic[OutShare]): - def __init__(self, out_share: OutShare): - self.out_share = out_share - - def __eq__(self, other: object) -> bool: - return isinstance(other, Finished) and \ - self.out_share == other.out_share - -class Rejected(State): - pass -~~~ +Aggregator has recovered an output share `out_share`. For completeness, we +define these states in {{topo-states}}. The methods described in this section are defined in terms of opaque byte strings. A compatible `Vdaf` MUST specify methods for encoding public shares, @@ -2013,7 +1979,7 @@ are used to encode measurements in some variants of Prio3 ({{prio3-instantiations}}). ~~~ python -def encode_into_bit_vector( +def encode_into_bit_vec( cls, val: int, bits: int) -> list[Self]: @@ -2036,7 +2002,7 @@ def encode_into_bit_vector( encoded.append(cls((val >> l) & 1)) return encoded -def decode_from_bit_vector(cls, vec: list[Self]) -> Self: +def decode_from_bit_vec(cls, vec: list[Self]) -> Self: """ Decode the field element from the bit representation, expressed as a vector of field elements `vec`. @@ -2395,7 +2361,7 @@ for various types of aggregation tasks are specified in {{prio3-instantiations}}. Test vectors for each variant can be found in {{test-vectors}}. -## Fully Linear Proof (FLP) Systems {#flp} +## Fully Linear Proofs (FLPs) {#flp} Conceptually, an FLP is a two-party protocol executed by a prover and a verifier. In actual use, however, the prover's computation is carried out by @@ -2604,7 +2570,7 @@ accepts on each proof. See {{security-multiproof}} for guidance on choosing the field size and number of proofs. -## Construction {#prio3-construction} +## Specification {#prio3-construction} This section specifies `Prio3`, an implementation of the `Vdaf` interface defined in {{vdaf}}. The parameters and types required by the `Vdaf` interface @@ -2949,22 +2915,6 @@ def prep_init( prep_share = (verifiers_share, joint_rand_part) return (prep_state, prep_share) -def prep_next( - self, - _ctx: bytes, - prep_state: Prio3PrepState[F], - prep_msg: Optional[bytes] -) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: - joint_rand_seed = prep_msg - (out_share, corrected_joint_rand_seed) = prep_state - - # If joint randomness was used, check that the value computed by - # the Aggregators matches the value indicated by the Client. - if joint_rand_seed != corrected_joint_rand_seed: - raise ValueError('joint randomness check failed') - - return out_share - def prep_shares_to_prep( self, ctx: bytes, @@ -2993,6 +2943,22 @@ def prep_shares_to_prep( if self.flp.JOINT_RAND_LEN > 0: joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts) return joint_rand_seed + +def prep_next( + self, + _ctx: bytes, + prep_state: Prio3PrepState[F], + prep_msg: Optional[bytes] +) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: + joint_rand_seed = prep_msg + (out_share, corrected_joint_rand_seed) = prep_state + + # If joint randomness was used, check that the value computed by + # the Aggregators matches the value indicated by the Client. + if joint_rand_seed != corrected_joint_rand_seed: + raise ValueError('joint randomness check failed') + + return out_share ~~~ ### Validity of Aggregation Parameters @@ -3285,7 +3251,7 @@ struct { } Prio3AggShare; ~~~ -## FLP Construction {#flp-bbcggi19} +## FLP Specification {#flp-bbcggi19} | Parameter | Value | |:-----------------|:-------------------------------------------------| @@ -3899,25 +3865,25 @@ class Sum(Valid[int, int, F]): out.append(self.GADGETS[0].eval(self.field, [b])) range_check = self.offset * shares_inv + \ - self.field.decode_from_bit_vector(meas[:self.bits]) - \ - self.field.decode_from_bit_vector(meas[self.bits:]) + self.field.decode_from_bit_vec(meas[:self.bits]) - \ + self.field.decode_from_bit_vec(meas[self.bits:]) out.append(range_check) return out def encode(self, measurement: int) -> list[F]: encoded = [] - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( measurement, self.bits ) - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( measurement + self.offset.int(), self.bits ) return encoded def truncate(self, meas: list[F]) -> list[F]: - return [self.field.decode_from_bit_vector(meas[:self.bits])] + return [self.field.decode_from_bit_vec(meas[:self.bits])] def decode(self, output: list[F], _num_measurements: int) -> int: return output[0].int() @@ -4033,14 +3999,14 @@ class SumVec(Valid[list[int], list[int], F]): def encode(self, measurement: list[int]) -> list[F]: encoded = [] for val in measurement: - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( val, self.bits) return encoded def truncate(self, meas: list[F]) -> list[F]: truncated = [] for i in range(self.length): - truncated.append(self.field.decode_from_bit_vector( + truncated.append(self.field.decode_from_bit_vec( meas[i * self.bits: (i + 1) * self.bits] )) return truncated @@ -4309,7 +4275,7 @@ class MultihotCountVec(Valid[list[int], list[int], F]): count_vec = meas[:self.length] weight = sum(count_vec, self.field(0)) weight_reported = \ - self.field.decode_from_bit_vector(meas[self.length:]) + self.field.decode_from_bit_vec(meas[self.length:]) weight_check = self.offset*shares_inv + weight - \ weight_reported @@ -4327,7 +4293,7 @@ class MultihotCountVec(Valid[list[int], list[int], F]): encoded = [] encoded += count_vec - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( (self.offset + weight_reported).int(), self.bits_for_weight) return encoded @@ -4545,7 +4511,7 @@ fixed length. Further, the padding scheme must be non-ambiguous. For example, each input could be padded with `b"\x01"` followed by as many `b"\x00"` bytes as needed. -## Construction {#poplar1-construction} +## Specification {#poplar1-construction} This section specifies `Poplar1`, an implementation of the `Vdaf` interface ({{vdaf}}). It is defined in terms of the `Idpf` implementation of @@ -5224,7 +5190,7 @@ than is strictly needed. In particular, it may be sufficient to convey which indices from the previous execution will have their children included in the next. This would help reduce communication overhead. -## IDPF Construction {#idpf-bbcggi21} +## IDPF Specification {#idpf-bbcggi21} In this section we specify a concrete IDPF suitable for instantiating Poplar1. The constant and type definitions required by the `Idpf` interface are given in @@ -6142,6 +6108,40 @@ class QueryGadget(Gadget[F]): return wrapped_valid ~~~ +# VDAF Preparation State {#topo-states} + +This section lists the classes used to define each Aggregator's state during +VDAF preparation ({{vdaf-prep-comm}}). + +~~~ python +class State: + pass + +class Start(State): + pass + +class Continued(State, Generic[PrepState]): + def __init__(self, prep_state: PrepState, prep_round: int): + self.prep_state = prep_state + self.prep_round = prep_round + + def __eq__(self, other: object) -> bool: + return isinstance(other, Continued) and \ + self.prep_state == other.prep_state and \ + self.prep_round == other.prep_round + +class Finished(State, Generic[OutShare]): + def __init__(self, out_share: OutShare): + self.out_share = out_share + + def __eq__(self, other: object) -> bool: + return isinstance(other, Finished) and \ + self.out_share == other.out_share + +class Rejected(State): + pass +~~~ + # Test Vectors {#test-vectors} Test vectors for Prio3 ({{prio3}}) and Poplar1 ({{poplar1}}) are available at diff --git a/poc/tests/test_field.py b/poc/tests/test_field.py index e4590940..0c3aa3a4 100644 --- a/poc/tests/test_field.py +++ b/poc/tests/test_field.py @@ -38,8 +38,8 @@ def run_field_test(self, cls: type[Field]) -> None: vals = [i for i in range(15)] bits = 4 for val in vals: - encoded = cls.encode_into_bit_vector(val, bits) - self.assertTrue(cls.decode_from_bit_vector( + encoded = cls.encode_into_bit_vec(val, bits) + self.assertTrue(cls.decode_from_bit_vec( encoded).int() == val) def run_ntt_field_test(self, cls: type[NttField]) -> None: diff --git a/poc/vdaf_poc/daf.py b/poc/vdaf_poc/daf.py index a74f8d6b..f05c2b61 100644 --- a/poc/vdaf_poc/daf.py +++ b/poc/vdaf_poc/daf.py @@ -168,12 +168,10 @@ def run_daf( Pre-conditions: - - `type(agg_param) == daf.AggParam` - - `type(measurement) == daf.Measurement` for each - `measurement` in `measurements` - - `len(nonce) == daf.NONCE_SIZE` for each `nonce` in `nonces` - `len(nonces) == len(measurements)` + - `all(len(nonce) == daf.NONCE_SIZE for nonce in nonces)` """ + # REMOVE ME if any(len(nonce) != daf.NONCE_SIZE for nonce in nonces): raise ValueError("incorrect nonce size") if len(nonces) != len(measurements): @@ -185,14 +183,12 @@ def run_daf( agg_shares = [daf.agg_init(agg_param) for _ in range(daf.SHARES)] for (measurement, nonce) in zip(measurements, nonces): - # Each Client shards its measurement into input shares and - # distributes them among the Aggregators. + # Sharding rand = gen_rand(daf.RAND_SIZE) (public_share, input_shares) = \ daf.shard(ctx, measurement, nonce, rand) - # Each Aggregator computes its output share from its input - # share and aggregates it. + # Preparation, aggregation for j in range(daf.SHARES): out_share = daf.prep(ctx, j, agg_param, nonce, public_share, input_shares[j]) @@ -200,7 +196,7 @@ def run_daf( agg_shares[j], out_share) - # Collector unshards the aggregate result. + # Unsharding num_measurements = len(measurements) agg_result = daf.unshard(agg_param, agg_shares, num_measurements) diff --git a/poc/vdaf_poc/field.py b/poc/vdaf_poc/field.py index 27f7d9b2..dc04ad85 100644 --- a/poc/vdaf_poc/field.py +++ b/poc/vdaf_poc/field.py @@ -68,14 +68,14 @@ def decode_vec(cls, encoded: bytes) -> list[Self]: vec.append(cls(x)) return vec - # NOTE: The encode_into_bit_vector() and decode_from_bit_vector() + # NOTE: The encode_into_bit_vec() and decode_from_bit_vec() # methods are excerpted in the document, de-indented, as the figure # {{field-bit-rep}}. Their width should be limited to 69 columns # after de-indenting, or 73 columns before de-indenting, to avoid # warnings from xml2rfc. # =================================================================== @classmethod - def encode_into_bit_vector( + def encode_into_bit_vec( cls, val: int, bits: int) -> list[Self]: @@ -99,7 +99,7 @@ def encode_into_bit_vector( return encoded @classmethod - def decode_from_bit_vector(cls, vec: list[Self]) -> Self: + def decode_from_bit_vec(cls, vec: list[Self]) -> Self: """ Decode the field element from the bit representation, expressed as a vector of field elements `vec`. diff --git a/poc/vdaf_poc/flp_bbcggi19.py b/poc/vdaf_poc/flp_bbcggi19.py index 2d0a2565..b2b975c6 100644 --- a/poc/vdaf_poc/flp_bbcggi19.py +++ b/poc/vdaf_poc/flp_bbcggi19.py @@ -855,7 +855,7 @@ def eval( count_vec = meas[:self.length] weight = sum(count_vec, self.field(0)) weight_reported = \ - self.field.decode_from_bit_vector(meas[self.length:]) + self.field.decode_from_bit_vec(meas[self.length:]) weight_check = self.offset*shares_inv + weight - \ weight_reported @@ -873,7 +873,7 @@ def encode(self, measurement: list[int]) -> list[F]: encoded = [] encoded += count_vec - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( (self.offset + weight_reported).int(), self.bits_for_weight) return encoded @@ -982,14 +982,14 @@ def encode(self, measurement: list[int]) -> list[F]: 'entry of measurement vector is out of range' ) - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( val, self.bits) return encoded def truncate(self, meas: list[F]) -> list[F]: truncated = [] for i in range(self.length): - truncated.append(self.field.decode_from_bit_vector( + truncated.append(self.field.decode_from_bit_vec( meas[i * self.bits: (i + 1) * self.bits] )) return truncated @@ -1070,25 +1070,25 @@ def eval( out.append(self.GADGETS[0].eval(self.field, [b])) range_check = self.offset * shares_inv + \ - self.field.decode_from_bit_vector(meas[:self.bits]) - \ - self.field.decode_from_bit_vector(meas[self.bits:]) + self.field.decode_from_bit_vec(meas[:self.bits]) - \ + self.field.decode_from_bit_vec(meas[self.bits:]) out.append(range_check) return out def encode(self, measurement: int) -> list[F]: encoded = [] - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( measurement, self.bits ) - encoded += self.field.encode_into_bit_vector( + encoded += self.field.encode_into_bit_vec( measurement + self.offset.int(), self.bits ) return encoded def truncate(self, meas: list[F]) -> list[F]: - return [self.field.decode_from_bit_vector(meas[:self.bits])] + return [self.field.decode_from_bit_vec(meas[:self.bits])] def decode(self, output: list[F], _num_measurements: int) -> int: return output[0].int() diff --git a/poc/vdaf_poc/vdaf.py b/poc/vdaf_poc/vdaf.py index 4d9017dd..4f841dd2 100644 --- a/poc/vdaf_poc/vdaf.py +++ b/poc/vdaf_poc/vdaf.py @@ -265,7 +265,7 @@ def run_vdaf( - `len(nonces) == len(measurements)` - `all(len(nonce) == vdaf.NONCE_SIZE for nonce in nonces)` """ - + # REMOVE ME if len(verify_key) != vdaf.VERIFY_KEY_SIZE: raise ValueError("incorrect verify_key size") if any(len(nonce) != vdaf.NONCE_SIZE for nonce in nonces): @@ -280,12 +280,12 @@ def run_vdaf( for (nonce, measurement) in zip(nonces, measurements): assert len(nonce) == vdaf.NONCE_SIZE - # Each Client shards its measurement into input shares. + # Sharding rand = gen_rand(vdaf.RAND_SIZE) (public_share, input_shares) = \ vdaf.shard(ctx, measurement, nonce, rand) - # Each Aggregator initializes its preparation state. + # Initialize preparation prep_states = [] outbound_prep_shares = [] for j in range(vdaf.SHARES): @@ -297,7 +297,7 @@ def run_vdaf( prep_states.append(state) outbound_prep_shares.append(share) - # Aggregators complete preparation. + # Complete preparation for i in range(vdaf.ROUNDS - 1): prep_msg = vdaf.prep_shares_to_prep(ctx, agg_param, @@ -314,7 +314,7 @@ def run_vdaf( agg_param, outbound_prep_shares) - # Each Aggregator computes and aggregates its output share. + # Aggregation for j in range(vdaf.SHARES): out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg) assert not isinstance(out_share, tuple) @@ -322,7 +322,7 @@ def run_vdaf( agg_shares[j], out_share) - # Collector unshards the aggregate. + # Unsharding num_measurements = len(measurements) agg_result = vdaf.unshard(agg_param, agg_shares, num_measurements) diff --git a/poc/vdaf_poc/vdaf_prio3.py b/poc/vdaf_poc/vdaf_prio3.py index bbfac282..bdbbe057 100644 --- a/poc/vdaf_poc/vdaf_prio3.py +++ b/poc/vdaf_poc/vdaf_prio3.py @@ -124,7 +124,7 @@ def is_valid( previous_agg_params: list[None]) -> bool: return len(previous_agg_params) == 0 - # NOTE: The prep_init(), prep_next(), and prep_shares_to_prep() + # NOTE: The prep_init(), prep_shares_to_prep(), and prep_next() # methods are excerpted in the document, de-indented, as figure # {{prio3-prep-state}}. Their width should be limited to 69 columns # after de-indenting, or 73 columns before de-indenting, to avoid @@ -183,22 +183,6 @@ def prep_init( prep_share = (verifiers_share, joint_rand_part) return (prep_state, prep_share) - def prep_next( - self, - _ctx: bytes, - prep_state: Prio3PrepState[F], - prep_msg: Optional[bytes] - ) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: - joint_rand_seed = prep_msg - (out_share, corrected_joint_rand_seed) = prep_state - - # If joint randomness was used, check that the value computed by - # the Aggregators matches the value indicated by the Client. - if joint_rand_seed != corrected_joint_rand_seed: - raise ValueError('joint randomness check failed') - - return out_share - def prep_shares_to_prep( self, ctx: bytes, @@ -228,6 +212,22 @@ def prep_shares_to_prep( joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts) return joint_rand_seed + def prep_next( + self, + _ctx: bytes, + prep_state: Prio3PrepState[F], + prep_msg: Optional[bytes] + ) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: + joint_rand_seed = prep_msg + (out_share, corrected_joint_rand_seed) = prep_state + + # If joint randomness was used, check that the value computed by + # the Aggregators matches the value indicated by the Client. + if joint_rand_seed != corrected_joint_rand_seed: + raise ValueError('joint randomness check failed') + + return out_share + # NOTE: Methods `agg_init()`, `agg_update()`, and `merge()` are # excerpted in the document, de-indented, as figure # {{prio3-out2agg}}. The width should be limited to 69 columns after