diff --git a/poc/tests/test_daf.py b/poc/tests/test_daf.py
index c52216d1..e921e225 100644
--- a/poc/tests/test_daf.py
+++ b/poc/tests/test_daf.py
@@ -70,11 +70,14 @@ def prep(
         # share.
         return input_share
 
-    def aggregate(
-            self,
-            _agg_param: None,
-            out_shares: list[Field128]) -> Field128:
-        return reduce(lambda x, y: x + y, out_shares)
+    def agg_init(self, _agg_param: None) -> Field128:
+        return Field128(0)
+
+    def agg_update(self,
+                   _agg_param: None,
+                   agg_share: Field128,
+                   agg_delta: Field128) -> Field128:
+        return agg_share + agg_delta
 
     def unshard(
             self,
diff --git a/poc/tests/test_vdaf_ping_pong.py b/poc/tests/test_vdaf_ping_pong.py
index 19841380..bf8968af 100644
--- a/poc/tests/test_vdaf_ping_pong.py
+++ b/poc/tests/test_vdaf_ping_pong.py
@@ -90,8 +90,14 @@ def prep_next(self,
             "prep round {}".format(current_round+1),
         )
 
-    def aggregate(self, _agg_param: int, out_shares: list[int]) -> int:
-        return sum(out_shares)
+    def agg_init(self, _agg_param: int) -> int:
+        return 0
+
+    def agg_update(self,
+                   _agg_param: int,
+                   agg_share: int,
+                   agg_delta: int) -> int:
+        return agg_share + agg_delta
 
     def unshard(self,
                 agg_param: int,
diff --git a/poc/vdaf_poc/daf.py b/poc/vdaf_poc/daf.py
index 8c6165a8..6f2bdb87 100644
--- a/poc/vdaf_poc/daf.py
+++ b/poc/vdaf_poc/daf.py
@@ -105,14 +105,22 @@ def prep(
         pass
 
     @abstractmethod
-    def aggregate(
+    def agg_init(self,
+                 agg_param: AggParam) -> AggShare:
+        """
+        Return an empty aggregate share.
+        """
+        pass
+
+    @abstractmethod
+    def agg_update(
             self,
             agg_param: AggParam,
-            out_shares: list[OutShare]) -> AggShare:
+            agg_share: AggShare,
+            agg_delta: AggShare | OutShare) -> AggShare:
         """
-        Merge a list of output shares into an aggregate share, encoded as a
-        byte string. This is called by an Aggregator after recovering a batch
-        of output shares.
+        Accumulate an output share or an aggregate share into an aggregate
+        share and return the updated aggregate share.
        """
         pass
 
@@ -165,8 +173,9 @@ def run_daf(
             "measurements and nonces lists have different lengths"
         )
 
-    out_shares: list[list[OutShare]]
-    out_shares = [[] for j in range(daf.SHARES)]
+    agg_shares: list[AggShare]
+    agg_shares = [daf.agg_init(agg_param)
+                  for _ in range(daf.SHARES)]
     for (measurement, nonce) in zip(measurements, nonces):
         # Each Client shards its measurement into input shares and
         # distributes them among the Aggregators.
@@ -174,19 +183,14 @@
         (public_share, input_shares) = \
             daf.shard(ctx, measurement, nonce, rand)
 
-        # Each Aggregator prepares its input share for aggregation.
+        # Each Aggregator computes its output share from its input
+        # share and aggregates it.
         for j in range(daf.SHARES):
-            out_shares[j].append(
-                daf.prep(ctx, j, agg_param, nonce,
-                         public_share, input_shares[j]))
-
-    # Each Aggregator aggregates its output shares into an aggregate
-    # share and sends it to the Collector.
-    agg_shares = []
-    for j in range(daf.SHARES):
-        agg_share_j = daf.aggregate(agg_param,
-                                    out_shares[j])
-        agg_shares.append(agg_share_j)
+            out_share = daf.prep(ctx, j, agg_param, nonce,
+                                 public_share, input_shares[j])
+            agg_shares[j] = daf.agg_update(agg_param,
+                                           agg_shares[j],
+                                           out_share)
 
     # Collector unshards the aggregate result.
     num_measurements = len(measurements)
diff --git a/poc/vdaf_poc/test_utils.py b/poc/vdaf_poc/test_utils.py
index 61207fe4..4fd0ca87 100644
--- a/poc/vdaf_poc/test_utils.py
+++ b/poc/vdaf_poc/test_utils.py
@@ -221,7 +221,9 @@ def gen_test_vec_for_vdaf(
     agg_shares = []
     for j in range(vdaf.SHARES):
         out_shares_j = [out[j] for out in out_shares]
-        agg_share_j = vdaf.aggregate(agg_param, out_shares_j)
+        agg_share_j = vdaf.agg_init(agg_param)
+        for out_share in out_shares_j:
+            agg_share_j = vdaf.agg_update(agg_param, agg_share_j, out_share)
         agg_shares.append(agg_share_j)
         # REMOVE ME
         test_vec['agg_shares'].append(
diff --git a/poc/vdaf_poc/vdaf.py b/poc/vdaf_poc/vdaf.py
index 1b9f5091..1d8d08e5 100644
--- a/poc/vdaf_poc/vdaf.py
+++ b/poc/vdaf_poc/vdaf.py
@@ -143,13 +143,22 @@ def prep_shares_to_prep(self,
         pass
 
     @abstractmethod
-    def aggregate(self,
-                  agg_param: AggParam,
-                  out_shares: list[OutShare]) -> AggShare:
+    def agg_init(self,
+                 agg_param: AggParam) -> AggShare:
+        """
+        Return an empty aggregate share.
+        """
+        pass
+
+    @abstractmethod
+    def agg_update(
+            self,
+            agg_param: AggParam,
+            agg_share: AggShare,
+            agg_delta: AggShare | OutShare) -> AggShare:
         """
-        Merge a list of output shares into an aggregate share, encoded as a byte
-        string. This is called by an aggregator after recovering a batch of
-        output shares.
+        Accumulate an output share or an aggregate share into an aggregate
+        share and return the updated aggregate share.
         """
         pass
 
@@ -254,7 +263,8 @@ def run_vdaf(
             "measurements and nonces lists have different lengths"
         )
 
-    out_shares = []
+    agg_shares = [vdaf.agg_init(agg_param)
+                  for _ in range(vdaf.SHARES)]
     for (nonce, measurement) in zip(nonces, measurements):
         assert len(nonce) == vdaf.NONCE_SIZE
 
@@ -275,7 +285,7 @@
             prep_states.append(state)
             outbound_prep_shares.append(share)
-        # Aggregators recover their output shares.
+        # Aggregators complete preparation.
         for i in range(vdaf.ROUNDS - 1):
             prep_msg = vdaf.prep_shares_to_prep(ctx,
                                                 agg_param,
                                                 outbound_prep_shares)
@@ -288,28 +298,17 @@
                 (prep_states[j], prep_share) = out
                 outbound_prep_shares.append(prep_share)
 
-        # The final outputs of the prepare phase are the output
-        # shares.
         prep_msg = vdaf.prep_shares_to_prep(ctx,
                                             agg_param,
                                             outbound_prep_shares)
 
-        outbound_out_shares = []
+        # Each Aggregator computes and aggregates its output share.
         for j in range(vdaf.SHARES):
             out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg)
             assert not isinstance(out_share, tuple)
-            outbound_out_shares.append(out_share)
-
-        out_shares.append(outbound_out_shares)
-
-    # Each Aggregator aggregates its output shares into an
-    # aggregate share. In a distributed VDAF computation, the
-    # aggregate shares are sent over the network.
-    agg_shares = []
-    for j in range(vdaf.SHARES):
-        out_shares_j = [out[j] for out in out_shares]
-        agg_share_j = vdaf.aggregate(agg_param, out_shares_j)
-        agg_shares.append(agg_share_j)
+            agg_shares[j] = vdaf.agg_update(agg_param,
+                                            agg_shares[j],
+                                            out_share)
 
     # Collector unshards the aggregate.
     num_measurements = len(measurements)
diff --git a/poc/vdaf_poc/vdaf_poplar1.py b/poc/vdaf_poc/vdaf_poplar1.py
index b719a0eb..f1e7a819 100644
--- a/poc/vdaf_poc/vdaf_poplar1.py
+++ b/poc/vdaf_poc/vdaf_poplar1.py
@@ -376,21 +376,24 @@ def prep_shares_to_prep(
         else:
             raise ValueError('incorrect sketch length')
 
-    # NOTE: This method is excerpted in the document, de-indented, as
-    # figure {{poplar1-out2agg}}. Its width should be limited to 69
-    # columns after de-indenting, or 73 columns before de-indenting, to
-    # avoid warnings from xml2rfc.
+    # NOTE: Methods `agg_init()` and `agg_update()` are excerpted in the
+    # document, de-indented, as figure {{poplar1-out2agg}}. The width
+    # should be limited to 69 columns after de-indenting, or 73 columns
+    # before de-indenting, to avoid warnings from xml2rfc.
     # ===================================================================
-    def aggregate(
-            self,
-            agg_param: Poplar1AggParam,
-            out_shares: list[FieldVec]) -> FieldVec:
+    def agg_init(self, agg_param: Poplar1AggParam) -> FieldVec:
         (level, prefixes) = agg_param
         field = self.idpf.current_field(level)
-        agg_share = cast(list[Field], field.zeros(len(prefixes)))
-        for out_share in out_shares:
-            agg_share = vec_add(agg_share, cast(list[Field], out_share))
-        return cast(FieldVec, agg_share)
+        return field.zeros(len(prefixes))
+
+    def agg_update(self,
+                   agg_param: Poplar1AggParam,
+                   agg_share: FieldVec,
+                   agg_delta: FieldVec) -> FieldVec:
+        (level, _prefixes) = agg_param
+        a = cast(list[Field], agg_share)
+        d = cast(list[Field], agg_delta)
+        return cast(FieldVec, vec_add(a, d))
 
     # NOTE: This method is excerpted in the document, de-indented, as
     # figure {{poplar1-agg-output}}. Its width should be limited to 69
diff --git a/poc/vdaf_poc/vdaf_prio3.py b/poc/vdaf_poc/vdaf_prio3.py
index 28bbbb3a..e37c0a7f 100644
--- a/poc/vdaf_poc/vdaf_prio3.py
+++ b/poc/vdaf_poc/vdaf_prio3.py
@@ -232,19 +232,19 @@ def prep_shares_to_prep(
             joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts)
         return joint_rand_seed
 
-    # NOTE: This method is excerpted in the document, de-indented, as
-    # figure {{prio3-out2agg}}. Its width should be limited to 69 columns
-    # after de-indenting, or 73 columns before de-indenting, to avoid
-    # warnings from xml2rfc.
+    # NOTE: Methods `agg_init()` and `agg_update()` are excerpted in the
+    # document, de-indented, as figure {{prio3-out2agg}}. The width
+    # should be limited to 69 columns after de-indenting, or 73 columns
+    # before de-indenting, to avoid warnings from xml2rfc.
     # ===================================================================
-    def aggregate(
-            self,
-            _agg_param: None,
-            out_shares: list[list[F]]) -> list[F]:
-        agg_share = self.flp.field.zeros(self.flp.OUTPUT_LEN)
-        for out_share in out_shares:
-            agg_share = vec_add(agg_share, out_share)
-        return agg_share
+    def agg_init(self, _agg_param: None) -> list[F]:
+        return self.flp.field.zeros(self.flp.OUTPUT_LEN)
+
+    def agg_update(self,
+                   _agg_param: None,
+                   agg_share: list[F],
+                   agg_delta: list[F]) -> list[F]:
+        return vec_add(agg_share, agg_delta)
 
     # NOTE: This method is excerpted in the document, de-indented, as
     # figure {{prio3-agg-output}}. Its width should be limited to 69
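
For reference, the behavior of the removed one-shot aggregate() method can be
recovered from the new interface by folding each output share into an initial
aggregate share. A minimal sketch, assuming a `daf` object implementing the
Daf interface from poc/vdaf_poc/daf.py; `agg_param` and `out_shares` are
caller-supplied placeholders, and `aggregate_all()` and `merge()` are
hypothetical helpers, not part of the patch:

    from functools import reduce

    def aggregate_all(daf, agg_param, out_shares):
        # Fold each output share into the empty aggregate share,
        # mirroring what run_daf() and run_vdaf() now do per report.
        return reduce(
            lambda agg_share, out_share: daf.agg_update(
                agg_param, agg_share, out_share),
            out_shares,
            daf.agg_init(agg_param))

    def merge(daf, agg_param, agg_share_a, agg_share_b):
        # agg_delta may itself be an aggregate share, so two partial
        # aggregates can be combined the same way.
        return daf.agg_update(agg_param, agg_share_a, agg_share_b)

The same folding applies to Vdaf implementations, since agg_update() is typed
to accept either an output share or an aggregate share as agg_delta.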