Replace aggregate() with streaming API
Modify both `Daf` and `Vdaf` by replacing `aggregate()` with two
methods:

* `agg_init()` returns an empty aggregate share
* `agg_update()` updates an aggregate share with an output share or
  another aggregate share
cjpatton committed Oct 7, 2024
1 parent fe5027f commit 1f26b7f
Showing 7 changed files with 91 additions and 74 deletions.
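As an illustration only (not part of the commit), here is a minimal sketch of the calling pattern the new interface enables, assuming a hypothetical DAF instance `daf`, an aggregation parameter `agg_param`, and the output shares recovered by one Aggregator:

def aggregate_streaming(daf, agg_param, out_shares):
    # Start from the empty aggregate share ...
    agg_share = daf.agg_init(agg_param)
    # ... and fold in each output share as it becomes available, rather
    # than buffering the whole batch for a single aggregate() call.
    for out_share in out_shares:
        agg_share = daf.agg_update(agg_param, agg_share, out_share)
    return agg_share

This mirrors what the updated run_daf(), run_vdaf(), and gen_test_vec_for_vdaf() in the diffs below now do.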
13 changes: 8 additions & 5 deletions poc/tests/test_daf.py
@@ -70,11 +70,14 @@ def prep(
         # share.
         return input_share

-    def aggregate(
-            self,
-            _agg_param: None,
-            out_shares: list[Field128]) -> Field128:
-        return reduce(lambda x, y: x + y, out_shares)
+    def agg_init(self, _agg_param: None) -> Field128:
+        return Field128(0)
+
+    def agg_update(self,
+                   _agg_param: None,
+                   agg_share: Field128,
+                   agg_delta: Field128) -> Field128:
+        return agg_share + agg_delta

     def unshard(
             self,
10 changes: 8 additions & 2 deletions poc/tests/test_vdaf_ping_pong.py
@@ -90,8 +90,14 @@ def prep_next(self,
             "prep round {}".format(current_round+1),
         )

-    def aggregate(self, _agg_param: int, out_shares: list[int]) -> int:
-        return sum(out_shares)
+    def agg_init(self, _agg_param: int) -> int:
+        return 0
+
+    def agg_update(self,
+                   _agg_param: int,
+                   agg_share: int,
+                   agg_delta: int) -> int:
+        return agg_share + agg_delta

     def unshard(self,
                 agg_param: int,
42 changes: 23 additions & 19 deletions poc/vdaf_poc/daf.py
@@ -105,14 +105,22 @@ def prep(
         pass

     @abstractmethod
-    def aggregate(
+    def agg_init(self,
+                 agg_param: AggParam) -> AggShare:
+        """
+        Return an empty aggregate share.
+        """
+        pass
+
+    @abstractmethod
+    def agg_update(
             self,
             agg_param: AggParam,
-            out_shares: list[OutShare]) -> AggShare:
+            agg_share: AggShare,
+            agg_delta: AggShare | OutShare) -> AggShare:
         """
-        Merge a list of output shares into an aggregate share, encoded as a
-        byte string. This is called by an Aggregator after recovering a batch
-        of output shares.
+        Accumulate an output share or an aggregate share into an aggregate
+        share and return the updated aggregate share.
         """
         pass

@@ -165,28 +173,24 @@ def run_daf(
             "measurements and nonces lists have different lengths"
         )

-    out_shares: list[list[OutShare]]
-    out_shares = [[] for j in range(daf.SHARES)]
+    agg_shares: list[AggShare]
+    agg_shares = [daf.agg_init(agg_param)
+                  for _ in range(daf.SHARES)]
     for (measurement, nonce) in zip(measurements, nonces):
         # Each Client shards its measurement into input shares and
         # distributes them among the Aggregators.
         rand = gen_rand(daf.RAND_SIZE)
         (public_share, input_shares) = \
             daf.shard(ctx, measurement, nonce, rand)

-        # Each Aggregator prepares its input share for aggregation.
+        # Each Aggregator computes its output share from its input
+        # share and aggregates it.
         for j in range(daf.SHARES):
-            out_shares[j].append(
-                daf.prep(ctx, j, agg_param, nonce,
-                         public_share, input_shares[j]))
-
-    # Each Aggregator aggregates its output shares into an aggregate
-    # share and sends it to the Collector.
-    agg_shares = []
-    for j in range(daf.SHARES):
-        agg_share_j = daf.aggregate(agg_param,
-                                    out_shares[j])
-        agg_shares.append(agg_share_j)
+            out_share = daf.prep(ctx, j, agg_param, nonce,
+                                 public_share, input_shares[j])
+            agg_shares[j] = daf.agg_update(agg_param,
+                                           agg_shares[j],
+                                           out_share)

     # Collector unshards the aggregate result.
     num_measurements = len(measurements)
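Because `agg_delta` is typed `AggShare | OutShare`, the same method also merges partial aggregate shares. A sketch under the same assumptions (hypothetical `daf`, `agg_param`, and a list of partial aggregate shares; not part of the diff):

def merge_agg_shares(daf, agg_param, partial_agg_shares):
    # Combine partial aggregate shares, e.g. produced by several
    # independent streaming passes, starting again from the empty share.
    agg_share = daf.agg_init(agg_param)
    for partial in partial_agg_shares:
        agg_share = daf.agg_update(agg_param, agg_share, partial)
    return agg_share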
4 changes: 3 additions & 1 deletion poc/vdaf_poc/test_utils.py
@@ -221,7 +221,9 @@ def gen_test_vec_for_vdaf(
     agg_shares = []
     for j in range(vdaf.SHARES):
         out_shares_j = [out[j] for out in out_shares]
-        agg_share_j = vdaf.aggregate(agg_param, out_shares_j)
+        agg_share_j = vdaf.agg_init(agg_param)
+        for out_share in out_shares_j:
+            agg_share_j = vdaf.agg_update(agg_param, agg_share_j, out_share)
         agg_shares.append(agg_share_j)
         # REMOVE ME
         test_vec['agg_shares'].append(
45 changes: 22 additions & 23 deletions poc/vdaf_poc/vdaf.py
@@ -143,13 +143,22 @@ def prep_shares_to_prep(self,
         pass

     @abstractmethod
-    def aggregate(self,
-                  agg_param: AggParam,
-                  out_shares: list[OutShare]) -> AggShare:
+    def agg_init(self,
+                 agg_param: AggParam) -> AggShare:
         """
+        Return an empty aggregate share.
+        """
+        pass
+
+    @abstractmethod
+    def agg_update(
+            self,
+            agg_param: AggParam,
+            agg_share: AggShare,
+            agg_delta: AggShare | OutShare) -> AggShare:
+        """
-        Merge a list of output shares into an aggregate share, encoded as a byte
-        string. This is called by an aggregator after recovering a batch of
-        output shares.
+        Accumulate an output share or an aggregate share into an aggregate
+        share and return the updated aggregate share.
         """
         pass

@@ -254,7 +263,8 @@ def run_vdaf(
             "measurements and nonces lists have different lengths"
         )

-    out_shares = []
+    agg_shares = [vdaf.agg_init(agg_param)
+                  for _ in range(vdaf.SHARES)]
     for (nonce, measurement) in zip(nonces, measurements):
         assert len(nonce) == vdaf.NONCE_SIZE

@@ -275,7 +285,7 @@
             prep_states.append(state)
             outbound_prep_shares.append(share)

-        # Aggregators recover their output shares.
+        # Aggregators completes preparation.
         for i in range(vdaf.ROUNDS - 1):
             prep_msg = vdaf.prep_shares_to_prep(ctx,
                                                 agg_param,
@@ -288,28 +298,17 @@
                 (prep_states[j], prep_share) = out
                 outbound_prep_shares.append(prep_share)

-        # The final outputs of the prepare phase are the output
-        # shares.
         prep_msg = vdaf.prep_shares_to_prep(ctx,
                                             agg_param,
                                             outbound_prep_shares)

-        outbound_out_shares = []
+        # Each Aggregator computes and aggregates its output share.
         for j in range(vdaf.SHARES):
             out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg)
             assert not isinstance(out_share, tuple)
-            outbound_out_shares.append(out_share)
-
-        out_shares.append(outbound_out_shares)
-
-    # Each Aggregator aggregates its output shares into an
-    # aggregate share. In a distributed VDAF computation, the
-    # aggregate shares are sent over the network.
-    agg_shares = []
-    for j in range(vdaf.SHARES):
-        out_shares_j = [out[j] for out in out_shares]
-        agg_share_j = vdaf.aggregate(agg_param, out_shares_j)
-        agg_shares.append(agg_share_j)
+            agg_shares[j] = vdaf.agg_update(agg_param,
+                                            agg_shares[j],
+                                            out_share)

     # Collector unshards the aggregate.
     num_measurements = len(measurements)
27 changes: 15 additions & 12 deletions poc/vdaf_poc/vdaf_poplar1.py
@@ -376,21 +376,24 @@ def prep_shares_to_prep(
         else:
             raise ValueError('incorrect sketch length')

-    # NOTE: This method is excerpted in the document, de-indented, as
-    # figure {{poplar1-out2agg}}. Its width should be limited to 69
-    # columns after de-indenting, or 73 columns before de-indenting, to
-    # avoid warnings from xml2rfc.
+    # NOTE: Methods `agg_init()` and `agg_update()` are excerpted in the
+    # document, de-indented, as figure {{poplar1-out2agg}}. The width
+    # should be limited to 69 columns after de-indenting, or 73 columns
+    # before de-indenting, to avoid warnings from xml2rfc.
     # ===================================================================
-    def aggregate(
-            self,
-            agg_param: Poplar1AggParam,
-            out_shares: list[FieldVec]) -> FieldVec:
+    def agg_init(self, agg_param: Poplar1AggParam) -> FieldVec:
         (level, prefixes) = agg_param
         field = self.idpf.current_field(level)
-        agg_share = cast(list[Field], field.zeros(len(prefixes)))
-        for out_share in out_shares:
-            agg_share = vec_add(agg_share, cast(list[Field], out_share))
-        return cast(FieldVec, agg_share)
+        return field.zeros(len(prefixes))
+
+    def agg_update(self,
+                   agg_param: Poplar1AggParam,
+                   agg_share: FieldVec,
+                   agg_delta: FieldVec) -> FieldVec:
+        (level, _prefixes) = agg_param
+        a = cast(list[Field], agg_share)
+        d = cast(list[Field], agg_delta)
+        return cast(FieldVec, vec_add(a, d))

     # NOTE: This method is excerpted in the document, de-indented, as
     # figure {{poplar1-agg-output}}. Its width should be limited to 69
24 changes: 12 additions & 12 deletions poc/vdaf_poc/vdaf_prio3.py
@@ -232,19 +232,19 @@ def prep_shares_to_prep(
         joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts)
         return joint_rand_seed

-    # NOTE: This method is excerpted in the document, de-indented, as
-    # figure {{prio3-out2agg}}. Its width should be limited to 69 columns
-    # after de-indenting, or 73 columns before de-indenting, to avoid
-    # warnings from xml2rfc.
+    # NOTE: Methods `agg_init()` and `agg_update()` are excerpted in the
+    # document, de-indented, as figure {{prio3-out2agg}}. The width
+    # should be limited to 69 columns after de-indenting, or 73 columns
+    # before de-indenting, to avoid warnings from xml2rfc.
     # ===================================================================
-    def aggregate(
-            self,
-            _agg_param: None,
-            out_shares: list[list[F]]) -> list[F]:
-        agg_share = self.flp.field.zeros(self.flp.OUTPUT_LEN)
-        for out_share in out_shares:
-            agg_share = vec_add(agg_share, out_share)
-        return agg_share
+    def agg_init(self, _agg_param: None) -> list[F]:
+        return self.flp.field.zeros(self.flp.OUTPUT_LEN)
+
+    def agg_update(self,
+                   _agg_param: None,
+                   agg_share: list[F],
+                   agg_delta: list[F]) -> list[F]:
+        return vec_add(agg_share, agg_delta)

     # NOTE: This method is excerpted in the document, de-indented, as
     # figure {{prio3-agg-output}}. Its width should be limited to 69
