From 0f590b23bafcf08d23a78c3978c67c38eb9868e4 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Fri, 27 Sep 2024 16:08:13 -0700 Subject: [PATCH] Add application context string Propagate changes to the reference code from a recent change that adds an application context string to sharding and preparation. Explain the intended role of context: disagreement should imply failure to aggregate a report. However, the effect is different for DAFs and VDAFs, since for DAFs there is no explicit signal of rejection; disagreement would manifest as a garbled aggregate result. Point this out. Also, note in security considerations that this provides a cushion for weak parameters. Finally, update the order of the context string in the `Daf` input parameters to align with the `Vdaf`. The context string comes ahead of every input except for the VDAF verification key. Co-authored-by: David Cook --- draft-irtf-cfrg-vdaf.md | 271 +++++++++++++++++++++++-------------- poc/tests/test_daf.py | 6 +- poc/vdaf_poc/daf.py | 18 +-- poc/vdaf_poc/vdaf_prio3.py | 7 +- 4 files changed, 185 insertions(+), 117 deletions(-) diff --git a/draft-irtf-cfrg-vdaf.md b/draft-irtf-cfrg-vdaf.md index 154a20dd..92aef184 100644 --- a/draft-irtf-cfrg-vdaf.md +++ b/draft-irtf-cfrg-vdaf.md @@ -890,11 +890,11 @@ In order to protect the privacy of its measurements, a DAF Client shards its measurements into a sequence of input shares. The `shard` method is used for this purpose. -* `daf.shard(measurement: Measurement, nonce: bytes, rand: bytes) -> - tuple[PublicShare, list[InputShare]]` is the randomized sharding algorithm - run by each Client that consumes a measurement and nonce and produces a - "public share" distributed to each of the Aggregate and a corresponding - sequence of input shares, one for each Aggregator. 
+* `daf.shard(ctx: bytes, measurement: Measurement, nonce: bytes, rand: bytes) + -> tuple[PublicShare, list[InputShare]]` is the randomized sharding algorithm + run by each Client that consumes the application context, a measurement, and + a nonce and produces a "public share" distributed to each of the Aggregators + and a corresponding sequence of input shares, one for each Aggregator. Pre-conditions: @@ -908,6 +908,14 @@ this purpose. * The number of input shares MUST equal `SHARES`. +Sharding is bound to a specific "application context". The application context +is a string intended to uniquely identify an instance of the higher level +protocol that uses the DAF. This is intended to ensure that aggregation succeeds +only if the Clients and Aggregators agree on the application context. +(Preparation binds the application context, too; see {{sec-daf-prepare}}.) Note +that, unlike VDAFs ({{vdaf}}), there is no explicit signal of disagreement; it +may only manifest as a garbled aggregate result. + ~~~~ Client ====== @@ -937,12 +945,12 @@ Once an Aggregator has received the public share and one of the input shares, the next step is to prepare the input share for aggregation. This is accomplished using the following algorithm: -* `daf.prep(agg_id: int, agg_param: AggParam, nonce: bytes, public_share: - PublicShare, input_share: InputShare) -> OutShare` is the deterministic - preparation algorithm. It takes as input the public share and one of the - input shares generated by a Client, the Aggregator's unique identifier, the - aggregation parameter selected by the Collector, and a nonce and returns an - output share. +* `daf.prep(ctx: bytes, agg_id: int, agg_param: AggParam, nonce: bytes, + public_share: PublicShare, input_share: InputShare) -> OutShare` is the + deterministic preparation algorithm. 
It takes as input the public share and + one of the input shares generated by a Client, the application context, the + Aggregator's unique identifier, the aggregation parameter selected by the + Collector, and a nonce and returns an output share. Pre-conditions: @@ -1056,6 +1064,7 @@ def run_daf( AggShare, AggResult, ], + ctx: bytes, agg_param: AggParam, measurements: list[Measurement], nonces: list[bytes]) -> AggResult: @@ -1084,12 +1093,12 @@ def run_daf( # distributes them among the Aggregators. rand = gen_rand(daf.RAND_SIZE) (public_share, input_shares) = \ - daf.shard(measurement, nonce, rand) + daf.shard(ctx, measurement, nonce, rand) # Each Aggregator prepares its input share for aggregation. for j in range(daf.SHARES): out_shares[j].append( - daf.prep(j, agg_param, nonce, + daf.prep(ctx, j, agg_param, nonce, public_share, input_shares[j])) # Each Aggregator aggregates its output shares into an aggregate @@ -1174,18 +1183,19 @@ of the `Vdaf` interface specify a method of encoding the `PublicShare`, Each VDAF is identified by a unique, 32-bit integer `ID`. Identifiers for each (V)DAF specified in this document are defined in {{codepoints}}. The following -method is defined for every VDAF: +method is defined for each VDAF specified in this document: ~~~ python -def domain_separation_tag(self, usage: int) -> bytes: +def domain_separation_tag(self, usage: int, ctx: bytes) -> bytes: """ - Format domain separation tag for this VDAF with the given usage. + Format domain separation tag for this VDAF with the given + application context and usage. Pre-conditions: - `usage` in `range(2**16)` """ - return format_dst(0, self.ID, usage) + return format_dst(0, self.ID, usage) + ctx ~~~ It is used to construct a domain separation tag for an instance of `Xof` used by @@ -1196,13 +1206,14 @@ the VDAF. (See {{xof}}.) Sharding transforms a measurement and nonce into a public share and input shares as it does in DAFs (cf. 
{{sec-daf-shard}}): -* `vdaf.shard(measurement: Measurement, nonce: bytes, rand: bytes) -> - tuple[PublicShare, list[InputShare]]` is the randomized sharding algorithm - run by each Client that consumes a measurement and nonce and produces a - public share distributed to each of the Aggregate and a corresponding - sequence of input shares, one for each Aggregator. Depending on the VDAF, the - input shares may encode additional information used to verify the recovered - output shares (e.g., the "proof shares" in Prio3 {{prio3}}) +* `vdaf.shard(ctx: bytes, measurement: Measurement, nonce: bytes, rand: bytes) + -> tuple[PublicShare, list[InputShare]]` is the randomized sharding algorithm + run by each Client that consumes the application context, a measurement, and + a nonce and produces a public share distributed to each of the Aggregators and + a corresponding sequence of input shares, one for each Aggregator. Depending + on the VDAF, the input shares may encode additional information used to + verify the recovered output shares (e.g., the "proof shares" in Prio3 + {{prio3}}) Pre-conditions: @@ -1216,6 +1227,14 @@ as it does in DAFs (cf. {{sec-daf-shard}}): * The number of input shares MUST equal `SHARES`. +Like DAFs, sharding is bound to the application context via the `ctx` string. +Again, this is intended to ensure that aggregation succeeds only if the Clients +and Aggregators agree on the application context. Unlike DAFs, however, +disagreement on the context should manifest as a failure to validate the +report, causing the report to be rejected without garbling the aggregate +result. The application context also provides some defense-in-depth against +cross protocol attacks; see {{deep}}. 
+ ## Preparation {#sec-vdaf-prepare} To recover and verify output shares, the Aggregators interact with one another @@ -1269,16 +1288,17 @@ share or an error."} To facilitate the preparation process, a concrete VDAF implements the following methods: -* `vdaf.prep_init(verify_key: bytes, agg_id: int, agg_param: AggParam, nonce: - bytes, public_share: PublicShare, input_share: InputShare) -> - tuple[PrepState, PrepShare]` is the deterministic preparation-state +* `vdaf.prep_init(verify_key: bytes, ctx: bytes, agg_id: int, agg_param: + AggParam, nonce: bytes, public_share: PublicShare, input_share: InputShare) + -> tuple[PrepState, PrepShare]` is the deterministic preparation-state initialization algorithm run by each Aggregator to begin processing its input share into an output share. Its inputs are the shared verification key - (`verify_key`), the Aggregator's unique identifier (`agg_id`), the - aggregation parameter (`agg_param`), the nonce provided by the environment - (`nonce`, see {{run-vdaf}}), the public share (`public_share`), and one of - the input shares generated by the Client (`input_share`). Its output is the - Aggregator's initial preparation state and initial prep share. + (`verify_key`), the application context (`ctx`), the Aggregator's unique + identifier (`agg_id`), the aggregation parameter (`agg_param`), the nonce + provided by the environment (`nonce`, see {{run-vdaf}}), the public share + (`public_share`), and one of the input shares generated by the Client + (`input_share`). Its output is the Aggregator's initial preparation state and + initial prep share. It is up to the high level protocol in which the VDAF is used to arrange for the distribution of the verification key prior to generating and processing @@ -1295,7 +1315,7 @@ methods: Client. * `nonce` MUST have length `vdaf.NONCE_SIZE`. 
-* `vdaf.prep_next(prep_state: PrepState, prep_msg: PrepMessage) -> +* `vdaf.prep_next(ctx: bytes, prep_state: PrepState, prep_msg: PrepMessage) -> tuple[PrepState, PrepShare] | OutShare` is the deterministic preparation-state update algorithm run by each Aggregator. It updates the Aggregator's preparation state (`prep_state`) and returns either its next @@ -1304,10 +1324,11 @@ methods: share could not be recovered. The input of this algorithm is the inbound preparation message. -* `vdaf.prep_shares_to_prep(agg_param: AggParam, prep_shares: list[PrepShare]) - -> PrepMessage` is the deterministic preparation-message pre-processing - algorithm. It combines the prep shares generated by the Aggregators in the - previous round into the prep message consumed by each in the next round. +* `vdaf.prep_shares_to_prep(ctx: bytes, agg_param: AggParam, prep_shares: + list[PrepShare]) -> PrepMessage` is the deterministic preparation-message + pre-processing algorithm. It combines the prep shares generated by the + Aggregators in the previous round into the prep message consumed by each in + the next round. In effect, each Aggregator moves through a linear state machine with `ROUNDS` states. The Aggregator enters the first state on using the initialization @@ -1389,6 +1410,7 @@ def run_vdaf( ], verify_key: bytes, agg_param: AggParam, + ctx: bytes, nonces: list[bytes], measurements: list[Measurement]) -> AggResult: """ @@ -1417,13 +1439,13 @@ def run_vdaf( # Each Client shards its measurement into input shares. rand = gen_rand(vdaf.RAND_SIZE) (public_share, input_shares) = \ - vdaf.shard(measurement, nonce, rand) + vdaf.shard(ctx, measurement, nonce, rand) # Each Aggregator initializes its preparation state. 
prep_states = [] outbound_prep_shares = [] for j in range(vdaf.SHARES): - (state, share) = vdaf.prep_init(verify_key, j, + (state, share) = vdaf.prep_init(verify_key, ctx, j, agg_param, nonce, public_share, @@ -1433,25 +1455,29 @@ def run_vdaf( # Aggregators recover their output shares. for i in range(vdaf.ROUNDS - 1): - prep_msg = vdaf.prep_shares_to_prep(agg_param, + prep_msg = vdaf.prep_shares_to_prep(ctx, + agg_param, outbound_prep_shares) + outbound_prep_shares = [] for j in range(vdaf.SHARES): - out = vdaf.prep_next(prep_states[j], prep_msg) + out = vdaf.prep_next(ctx, prep_states[j], prep_msg) assert isinstance(out, tuple) (prep_states[j], prep_share) = out outbound_prep_shares.append(prep_share) # The final outputs of the prepare phase are the output # shares. - prep_msg = vdaf.prep_shares_to_prep(agg_param, + prep_msg = vdaf.prep_shares_to_prep(ctx, + agg_param, outbound_prep_shares) outbound_out_shares = [] for j in range(vdaf.SHARES): - out_share = vdaf.prep_next(prep_states[j], prep_msg) + out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg) assert not isinstance(out_share, tuple) outbound_out_shares.append(out_share) + out_shares.append(outbound_out_shares) # Each Aggregator aggregates its output shares into an @@ -1606,6 +1632,7 @@ following method, implemented on `Vdaf`: def ping_pong_leader_init( self, vdaf_verify_key: bytes, + ctx: bytes, agg_param: bytes, nonce: bytes, public_share: bytes, @@ -1614,6 +1641,7 @@ def ping_pong_leader_init( try: (prep_state, prep_share) = self.prep_init( vdaf_verify_key, + ctx, 0, self.decode_agg_param(agg_param), nonce, @@ -1642,6 +1670,7 @@ transition is computed using the following procedure: def ping_pong_helper_init( self, vdaf_verify_key: bytes, + ctx: bytes, agg_param: bytes, nonce: bytes, public_share: bytes, @@ -1654,6 +1683,7 @@ def ping_pong_helper_init( try: (prep_state, prep_share) = self.prep_init( vdaf_verify_key, + ctx, 1, self.decode_agg_param(agg_param), nonce, @@ -1671,6 +1701,7 @@ def 
ping_pong_helper_init( prep_share, ] return self.ping_pong_transition( + ctx, self.decode_agg_param(agg_param), prep_shares, prep_state, @@ -1688,14 +1719,16 @@ the prep message, and computes the next prep state of the caller: ~~~ python def ping_pong_transition( self, + ctx: bytes, agg_param: AggParam, prep_shares: list[PrepShare], prep_state: PrepState, prep_round: int) -> tuple[State, bytes]: - prep_msg = self.prep_shares_to_prep(agg_param, + prep_msg = self.prep_shares_to_prep(ctx, + agg_param, prep_shares) encoded_prep_msg = self.encode_prep_msg(prep_msg) - out = self.prep_next(prep_state, prep_msg) + out = self.prep_next(ctx, prep_state, prep_msg) if prep_round+1 == self.ROUNDS: return ( Finished(out), @@ -1720,6 +1753,7 @@ its next state transition using the function `ping_pong_leader_continued`: ~~~ python def ping_pong_leader_continued( self, + ctx: bytes, agg_param: bytes, state: State, inbound: bytes, @@ -1728,11 +1762,12 @@ def ping_pong_leader_continued( Called by the leader to start the next step of ping-ponging. 
""" return self.ping_pong_continued( - True, agg_param, state, inbound) + True, ctx, agg_param, state, inbound) def ping_pong_continued( self, is_leader: bool, + ctx: bytes, agg_param: bytes, state: State, inbound: bytes, @@ -1751,7 +1786,7 @@ def ping_pong_continued( state.prep_state, encoded_prep_msg, ) - out = self.prep_next(state.prep_state, prep_msg) + out = self.prep_next(ctx, state.prep_state, prep_msg) if prep_round+1 < self.ROUNDS and \ inbound_type == 1: # continue (prep_state, prep_share) = cast( @@ -1767,6 +1802,7 @@ def ping_pong_continued( if is_leader: prep_shares.reverse() return self.ping_pong_transition( + ctx, self.decode_agg_param(agg_param), prep_shares, prep_state, @@ -1789,13 +1825,14 @@ computes its next state transition using the function ~~~ python def ping_pong_helper_continued( self, + ctx: bytes, agg_param: bytes, state: State, inbound: bytes, ) -> tuple[State, Optional[bytes]]: """Called by the helper to continue ping-ponging.""" return self.ping_pong_continued( - False, agg_param, state, inbound) + False, ctx, agg_param, state, inbound) ~~~ They continue in this way until processing halts. Note that, depending on the @@ -2105,7 +2142,7 @@ class XofTurboShake128(Xof): def __init__(self, seed: bytes, dst: bytes, binder: bytes): self.l = 0 self.m = \ - to_le_bytes(len(dst), 1) + dst \ + to_le_bytes(len(dst), 2) + dst \ to_le_bytes(len(seed), 1) + seed + \ binder @@ -2158,7 +2195,7 @@ class XofFixedKeyAes128(Xof): # # Implementation note: this step can be cached across XOF # evaluations with many different seeds. - dst_length = to_le_bytes(len(dst), 1) + dst_length = to_le_bytes(len(dst), 2) self.fixed_key = TurboSHAKE128( dst_length + dst + binder, 2, @@ -2556,6 +2593,7 @@ Depending on the FLP, joint randomness may not be required. 
In particular, when ~~~ python def shard( self, + ctx: bytes, measurement: Measurement, nonce: bytes, rand: bytes) -> tuple[ @@ -2571,9 +2609,9 @@ def shard( meas = self.flp.encode(measurement) if self.flp.JOINT_RAND_LEN > 0: - return self.shard_with_joint_rand(meas, nonce, seeds) + return self.shard_with_joint_rand(ctx, meas, nonce, seeds) else: - return self.shard_without_joint_rand(meas, seeds) + return self.shard_without_joint_rand(ctx, meas, seeds) ~~~ {: #prio3-eval-input title="Input-distribution algorithm for Prio3."} @@ -2590,6 +2628,7 @@ i.e., when `flp.JOINT_RAND_LEN == 0`: ~~~ python def shard_without_joint_rand( self, + ctx: bytes, meas: list[F], seeds: list[bytes]) -> tuple[ Optional[list[bytes]], @@ -2602,11 +2641,11 @@ def shard_without_joint_rand( for j in range(self.SHARES - 1): leader_meas_share = vec_sub( leader_meas_share, - self.helper_meas_share(j + 1, k_helper_shares[j]), + self.helper_meas_share(ctx, j + 1, k_helper_shares[j]), ) # Generate and shard each proof into shares. 
- prove_rands = self.prove_rands(k_prove) + prove_rands = self.prove_rands(ctx, k_prove) leader_proofs_share = [] for _ in range(self.PROOFS): prove_rand, prove_rands = front( @@ -2616,6 +2655,7 @@ def shard_without_joint_rand( leader_proofs_share = vec_sub( leader_proofs_share, self.helper_proofs_share( + ctx, j + 1, k_helper_shares[j], ), @@ -2661,6 +2701,7 @@ i.e., for which `flp.JOINT_RAND_LEN > 0`: ~~~ python def shard_with_joint_rand( self, + ctx: bytes, meas: list[F], nonce: bytes, seeds: list[bytes]) -> tuple[ @@ -2683,18 +2724,19 @@ def shard_with_joint_rand( k_joint_rand_parts = [] for j in range(self.SHARES - 1): helper_meas_share = self.helper_meas_share( - j + 1, k_helper_shares[j]) + ctx, j + 1, k_helper_shares[j]) leader_meas_share = vec_sub(leader_meas_share, helper_meas_share) k_joint_rand_parts.append(self.joint_rand_part( - j + 1, k_helper_blinds[j], helper_meas_share, nonce)) + ctx, j + 1, k_helper_blinds[j], + helper_meas_share, nonce)) k_joint_rand_parts.insert(0, self.joint_rand_part( - 0, k_leader_blind, leader_meas_share, nonce)) + ctx, 0, k_leader_blind, leader_meas_share, nonce)) # Generate the proof and shard it into proof shares. 
- prove_rands = self.prove_rands(k_prove) + prove_rands = self.prove_rands(ctx, k_prove) joint_rands = self.joint_rands( - self.joint_rand_seed(k_joint_rand_parts)) + ctx, self.joint_rand_seed(ctx, k_joint_rand_parts)) leader_proofs_share = [] for _ in range(self.PROOFS): prove_rand, prove_rands = front( @@ -2710,6 +2752,7 @@ def shard_with_joint_rand( leader_proofs_share = vec_sub( leader_proofs_share, self.helper_proofs_share( + ctx, j + 1, k_helper_shares[j], ), @@ -2792,6 +2835,7 @@ The definitions of constants and a few auxiliary functions are defined in def prep_init( self, verify_key: bytes, + ctx: bytes, agg_id: int, _agg_param: None, nonce: bytes, @@ -2801,7 +2845,7 @@ def prep_init( Prio3PrepShare[F]]: k_joint_rand_parts = public_share (meas_share, proofs_share, k_blind) = \ - self.expand_input_share(agg_id, input_share) + self.expand_input_share(ctx, agg_id, input_share) out_share = self.flp.truncate(meas_share) # Compute the joint randomness. @@ -2811,14 +2855,14 @@ def prep_init( assert k_blind is not None assert k_joint_rand_parts is not None k_joint_rand_part = self.joint_rand_part( - agg_id, k_blind, meas_share, nonce) + ctx, agg_id, k_blind, meas_share, nonce) k_joint_rand_parts[agg_id] = k_joint_rand_part k_corrected_joint_rand = self.joint_rand_seed( - k_joint_rand_parts) - joint_rands = self.joint_rands(k_corrected_joint_rand) + ctx, k_joint_rand_parts) + joint_rands = self.joint_rands(ctx, k_corrected_joint_rand) # Query the measurement and proof share. 
- query_rands = self.query_rands(verify_key, nonce) + query_rands = self.query_rands(verify_key, ctx, nonce) verifiers_share = [] for _ in range(self.PROOFS): proof_share, proofs_share = front( @@ -2841,10 +2885,11 @@ def prep_init( return (prep_state, prep_share) def prep_next( - self, - prep_state: Prio3PrepState[F], - prep_msg: Optional[bytes], - ) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: + self, + _ctx: bytes, + prep_state: Prio3PrepState[F], + prep_msg: Optional[bytes] +) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]: k_joint_rand = prep_msg (out_share, k_corrected_joint_rand) = prep_state @@ -2857,6 +2902,7 @@ def prep_next( def prep_shares_to_prep( self, + ctx: bytes, _agg_param: None, prep_shares: list[Prio3PrepShare[F]]) -> Optional[bytes]: # Unshard the verifier shares into the verifier message. @@ -2880,7 +2926,7 @@ def prep_shares_to_prep( # used in the last step. k_joint_rand = None if self.flp.JOINT_RAND_LEN > 0: - k_joint_rand = self.joint_rand_seed(k_joint_rand_parts) + k_joint_rand = self.joint_rand_seed(ctx, k_joint_rand_parts) return k_joint_rand ~~~ {: #prio3-prep-state title="Preparation state for Prio3."} @@ -2946,29 +2992,35 @@ algorithms for Prio3 in the preceding sections. The following methods are called by the sharding and preparation algorithms. 
~~~ python -def helper_meas_share(self, agg_id: int, k_share: bytes) -> list[F]: +def helper_meas_share( + self, + ctx: bytes, + agg_id: int, + k_share: bytes) -> list[F]: return self.xof.expand_into_vec( self.flp.field, k_share, - self.domain_separation_tag(USAGE_MEAS_SHARE), + self.domain_separation_tag(USAGE_MEAS_SHARE, ctx), byte(agg_id), self.flp.MEAS_LEN, ) def helper_proofs_share( self, + ctx: bytes, agg_id: int, k_share: bytes) -> list[F]: return self.xof.expand_into_vec( self.flp.field, k_share, - self.domain_separation_tag(USAGE_PROOF_SHARE), + self.domain_separation_tag(USAGE_PROOF_SHARE, ctx), byte(self.PROOFS) + byte(agg_id), self.flp.PROOF_LEN * self.PROOFS, ) def expand_input_share( self, + ctx: bytes, agg_id: int, input_share: Prio3InputShare[F]) -> tuple[ list[F], @@ -2977,57 +3029,66 @@ def expand_input_share( if agg_id > 0: assert len(input_share) == 2 (k_share, k_blind) = input_share - meas_share = self.helper_meas_share(agg_id, k_share) - proofs_share = self.helper_proofs_share(agg_id, k_share) + meas_share = self.helper_meas_share(ctx, agg_id, k_share) + proofs_share = self.helper_proofs_share(ctx, agg_id, k_share) else: assert len(input_share) == 3 (meas_share, proofs_share, k_blind) = input_share return (meas_share, proofs_share, k_blind) -def prove_rands(self, k_prove: bytes) -> list[F]: +def prove_rands(self, ctx: bytes, k_prove: bytes) -> list[F]: return self.xof.expand_into_vec( self.flp.field, k_prove, - self.domain_separation_tag(USAGE_PROVE_RANDOMNESS), + self.domain_separation_tag(USAGE_PROVE_RANDOMNESS, ctx), byte(self.PROOFS), self.flp.PROVE_RAND_LEN * self.PROOFS, ) -def query_rands(self, verify_key: bytes, nonce: bytes) -> list[F]: +def query_rands( + self, + verify_key: bytes, + ctx: bytes, + nonce: bytes) -> list[F]: return self.xof.expand_into_vec( self.flp.field, verify_key, - self.domain_separation_tag(USAGE_QUERY_RANDOMNESS), + self.domain_separation_tag(USAGE_QUERY_RANDOMNESS, ctx), byte(self.PROOFS) + nonce, 
self.flp.QUERY_RAND_LEN * self.PROOFS, ) def joint_rand_part( self, + ctx: bytes, agg_id: int, k_blind: bytes, meas_share: list[F], nonce: bytes) -> bytes: return self.xof.derive_seed( k_blind, - self.domain_separation_tag(USAGE_JOINT_RAND_PART), + self.domain_separation_tag(USAGE_JOINT_RAND_PART, ctx), byte(agg_id) + nonce + self.flp.field.encode_vec(meas_share), ) -def joint_rand_seed(self, k_joint_rand_parts: list[bytes]) -> bytes: +def joint_rand_seed(self, + ctx: bytes, + k_joint_rand_parts: list[bytes]) -> bytes: """Derive the joint randomness seed from its parts.""" return self.xof.derive_seed( zeros(self.xof.SEED_SIZE), - self.domain_separation_tag(USAGE_JOINT_RAND_SEED), + self.domain_separation_tag(USAGE_JOINT_RAND_SEED, ctx), concat(k_joint_rand_parts), ) -def joint_rands(self, k_joint_rand_seed: bytes) -> list[F]: +def joint_rands(self, + ctx: bytes, + k_joint_rand_seed: bytes) -> list[F]: """Derive the joint randomness from its seed.""" return self.xof.expand_into_vec( self.flp.field, k_joint_rand_seed, - self.domain_separation_tag(USAGE_JOINT_RANDOMNESS), + self.domain_separation_tag(USAGE_JOINT_RANDOMNESS, ctx), byte(self.PROOFS), self.flp.JOINT_RAND_LEN * self.PROOFS, ) @@ -4412,11 +4473,12 @@ follows. 
~~~ python def shard( - self, - measurement: tuple[bool, ...], - nonce: bytes, - rand: bytes, - ) -> tuple[Poplar1PublicShare, list[Poplar1InputShare]]: + self, + ctx: bytes, + measurement: tuple[bool, ...], + nonce: bytes, + rand: bytes, +) -> tuple[Poplar1PublicShare, list[Poplar1InputShare]]: if len(nonce) != self.NONCE_SIZE: raise ValueError("incorrect nonce size") if len(rand) != self.RAND_SIZE: @@ -4435,7 +4497,7 @@ def shard( xof = self.xof( k_shard, - self.domain_separation_tag(USAGE_SHARD_RAND), + self.domain_separation_tag(USAGE_SHARD_RAND, ctx), nonce, ) @@ -4469,14 +4531,14 @@ def shard( self.xof.expand_into_vec( self.idpf.field_inner, corr_seed[0], - self.domain_separation_tag(USAGE_CORR_INNER), + self.domain_separation_tag(USAGE_CORR_INNER, ctx), byte(0) + nonce, 3 * (self.idpf.BITS - 1), ), self.xof.expand_into_vec( self.idpf.field_inner, corr_seed[1], - self.domain_separation_tag(USAGE_CORR_INNER), + self.domain_separation_tag(USAGE_CORR_INNER, ctx), byte(1) + nonce, 3 * (self.idpf.BITS - 1), ), @@ -4485,14 +4547,14 @@ def shard( self.xof.expand_into_vec( self.idpf.field_leaf, corr_seed[0], - self.domain_separation_tag(USAGE_CORR_LEAF), + self.domain_separation_tag(USAGE_CORR_LEAF, ctx), byte(0) + nonce, 3, ), self.xof.expand_into_vec( self.idpf.field_leaf, corr_seed[1], - self.domain_separation_tag(USAGE_CORR_LEAF), + self.domain_separation_tag(USAGE_CORR_LEAF, ctx), byte(1) + nonce, 3, ), @@ -4548,6 +4610,7 @@ but this does allow uniqueness to be determined more efficiently. def prep_init( self, verify_key: bytes, + ctx: bytes, agg_id: int, agg_param: Poplar1AggParam, nonce: bytes, @@ -4574,7 +4637,7 @@ def prep_init( if level < self.idpf.BITS - 1: corr_xof = self.xof( corr_seed, - self.domain_separation_tag(USAGE_CORR_INNER), + self.domain_separation_tag(USAGE_CORR_INNER, ctx), byte(agg_id) + nonce, ) # Fast-forward the XOF state to the current level. 
@@ -4582,7 +4645,7 @@ def prep_init( else: corr_xof = self.xof( corr_seed, - self.domain_separation_tag(USAGE_CORR_LEAF), + self.domain_separation_tag(USAGE_CORR_LEAF, ctx), byte(agg_id) + nonce, ) (a_share, b_share, c_share) = corr_xof.next_vec(field, 3) @@ -4598,7 +4661,7 @@ def prep_init( # called the "masked input values" [BBCGGI21, Appendix C.4]. verify_rand_xof = self.xof( verify_key, - self.domain_separation_tag(USAGE_VERIFY_RAND), + self.domain_separation_tag(USAGE_VERIFY_RAND, ctx), nonce + to_be_bytes(level, 2), ) verify_rand = cast( @@ -4626,10 +4689,11 @@ def prep_init( ) def prep_next( - self, - prep_state: Poplar1PrepState, - prep_msg: Optional[FieldVec] - ) -> tuple[Poplar1PrepState, FieldVec] | FieldVec: + self, + _ctx: bytes, + prep_state: Poplar1PrepState, + prep_msg: Optional[FieldVec] +) -> tuple[Poplar1PrepState, FieldVec] | FieldVec: prev_sketch = cast(list[Field], prep_msg) (step, level, prep_mem) = prep_state @@ -4671,6 +4735,7 @@ def prep_next( def prep_shares_to_prep( self, + _ctx: bytes, agg_param: Poplar1AggParam, prep_shares: list[FieldVec]) -> Optional[FieldVec]: if len(prep_shares) != 2: @@ -5592,7 +5657,7 @@ may consider involving more than two Aggregators as described for example in {{star-topo}}. Note however that some schemes are not compatible with this mode of operation, such as Poplar1. -## Defense-in-Depth Measures +## Defense-in-Depth Measures {#deep} Prio3 and Poplar1 are designed to resist some attacks that fall outside the main threat model for VDAFs. @@ -5612,7 +5677,9 @@ For example: {{security-multiproof}}.) The joint randomness derivation is bound to both the field (via the algorithm ID) and the number of proofs, thereby ensuring that joint randomness derived for weak parameters is not reused for stronger - parameters. + parameters. 
In addition, the joint randomness is bound to the application + context, meaning any work the attacker does to attack some application is + not useful for other applications that use the same parameters. There are also some important limitations to be aware of. For example, Prio3 provides domain separation between families of circuits, but does not provide diff --git a/poc/tests/test_daf.py b/poc/tests/test_daf.py index fdbe29ab..c52216d1 100644 --- a/poc/tests/test_daf.py +++ b/poc/tests/test_daf.py @@ -27,8 +27,8 @@ class TestDaf( def shard( self, - measurement: int, _ctx: bytes, + measurement: int, nonce: bytes, rand: bytes) -> tuple[None, list[Field128]]: if len(nonce) != self.NONCE_SIZE: @@ -57,9 +57,9 @@ def is_valid( def prep( self, + _ctx: bytes, _agg_id: int, _agg_param: None, - _ctx: bytes, nonce: bytes, _public_share: None, input_share: Field128) -> Field128: @@ -114,8 +114,8 @@ def run_daf_test( # Run the DAF on the set of measurements. nonces = [gen_rand(daf.NONCE_SIZE) for _ in range(len(measurements))] agg_result = run_daf(daf, - agg_param, b'some application', + agg_param, measurements, nonces) self.assertEqual(agg_result, expected_agg_result) diff --git a/poc/vdaf_poc/daf.py b/poc/vdaf_poc/daf.py index be229638..8c6165a8 100644 --- a/poc/vdaf_poc/daf.py +++ b/poc/vdaf_poc/daf.py @@ -54,8 +54,8 @@ class Daf( @abstractmethod def shard( self, - measurement: Measurement, ctx: bytes, + measurement: Measurement, nonce: bytes, rand: bytes) -> tuple[PublicShare, list[InputShare]]: """ @@ -83,19 +83,19 @@ def is_valid( @abstractmethod def prep( self, + ctx: bytes, agg_id: int, agg_param: AggParam, - ctx: bytes, nonce: bytes, public_share: PublicShare, input_share: InputShare) -> OutShare: """ Prepare an input share for aggregation. This algorithm takes in the public share and one of the input shares generated by the Client. 
It - also takes the Aggregator's ID (a unique integer in range `[0, SHARES)` - corresponding to the index of `input_share` in the Client's output), - and an aggregation parameter and returns the corresponding output - share. + also takes in the application context, the Aggregator's ID (a unique + integer in range `[0, SHARES)` corresponding to the index of + `input_share` in the Client's output), and an aggregation parameter and + returns the corresponding output share. Pre-conditions: @@ -143,8 +143,8 @@ def run_daf( AggShare, AggResult, ], - agg_param: AggParam, ctx: bytes, + agg_param: AggParam, measurements: list[Measurement], nonces: list[bytes]) -> AggResult: """ @@ -172,12 +172,12 @@ def run_daf( # distributes them among the Aggregators. rand = gen_rand(daf.RAND_SIZE) (public_share, input_shares) = \ - daf.shard(measurement, ctx, nonce, rand) + daf.shard(ctx, measurement, nonce, rand) # Each Aggregator prepares its input share for aggregation. for j in range(daf.SHARES): out_shares[j].append( - daf.prep(j, agg_param, ctx, nonce, + daf.prep(ctx, j, agg_param, nonce, public_share, input_shares[j])) # Each Aggregator aggregates its output shares into an aggregate diff --git a/poc/vdaf_poc/vdaf_prio3.py b/poc/vdaf_poc/vdaf_prio3.py index bdd8b5c8..8dc15ee2 100644 --- a/poc/vdaf_poc/vdaf_prio3.py +++ b/poc/vdaf_poc/vdaf_prio3.py @@ -351,14 +351,15 @@ def shard_with_joint_rand( leader_meas_share = vec_sub(leader_meas_share, helper_meas_share) k_joint_rand_parts.append(self.joint_rand_part( - ctx, j + 1, k_helper_blinds[j], helper_meas_share, nonce)) + ctx, j + 1, k_helper_blinds[j], + helper_meas_share, nonce)) k_joint_rand_parts.insert(0, self.joint_rand_part( ctx, 0, k_leader_blind, leader_meas_share, nonce)) # Generate the proof and shard it into proof shares. 
prove_rands = self.prove_rands(ctx, k_prove) - joint_rands = self.joint_rands(ctx, - self.joint_rand_seed(ctx, k_joint_rand_parts)) + joint_rands = self.joint_rands( + ctx, self.joint_rand_seed(ctx, k_joint_rand_parts)) leader_proofs_share = [] for _ in range(self.PROOFS): prove_rand, prove_rands = front(