Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up after full pass #499

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 93 additions & 93 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,22 @@ derive per-report randomness for verification of the computation. See

## Preparation {#sec-daf-prepare}

~~~
Aggregator 0 Aggregator 1 Aggregator SHARES-1
============ ============ ===================

input_share_0 input_share_1 input_share_[SHARES-1]
| | ... |
V V V
+-----------+ +-----------+ +-----------+
| prep | | prep | | prep |
+-----------+ +-----------+ +-----------+
| | ... |
V V V
out_share_0 out_share_1 out_share_[SHARES-1]
~~~
{: #daf-prep-flow title="Illustration of preparation."}

Once an Aggregator has received the public share and its input share, the next
step is to prepare the input share for aggregation. This is accomplished using
the preparation algorithm:
Expand Down Expand Up @@ -1202,39 +1218,27 @@ def run_daf(

Pre-conditions:

- `type(agg_param) == daf.AggParam`
- `type(measurement) == daf.Measurement` for each
`measurement` in `measurements`
- `len(nonce) == daf.NONCE_SIZE` for each `nonce` in `nonces`
- `len(nonces) == len(measurements)`
- `all(len(nonce) == daf.NONCE_SIZE for nonce in nonces)`
"""
if any(len(nonce) != daf.NONCE_SIZE for nonce in nonces):
raise ValueError("incorrect nonce size")
if len(nonces) != len(measurements):
raise ValueError(
"measurements and nonces lists have different lengths"
)

agg_shares: list[AggShare]
agg_shares = [daf.agg_init(agg_param)
for _ in range(daf.SHARES)]
for (measurement, nonce) in zip(measurements, nonces):
# Each Client shards its measurement into input shares and
# distributes them among the Aggregators.
# Sharding
rand = gen_rand(daf.RAND_SIZE)
(public_share, input_shares) = \
daf.shard(ctx, measurement, nonce, rand)

# Each Aggregator computes its output share from its input
# share and aggregates it.
# Preparation, aggregation
for j in range(daf.SHARES):
out_share = daf.prep(ctx, j, agg_param, nonce,
public_share, input_shares[j])
agg_shares[j] = daf.agg_update(agg_param,
agg_shares[j],
out_share)

# Collector unshards the aggregate result.
# Unsharding
num_measurements = len(measurements)
agg_result = daf.unshard(agg_param, agg_shares,
num_measurements)
Expand Down Expand Up @@ -1491,27 +1495,17 @@ def run_vdaf(
- `len(nonces) == len(measurements)`
- `all(len(nonce) == vdaf.NONCE_SIZE for nonce in nonces)`
"""

if len(verify_key) != vdaf.VERIFY_KEY_SIZE:
raise ValueError("incorrect verify_key size")
if any(len(nonce) != vdaf.NONCE_SIZE for nonce in nonces):
raise ValueError("incorrect nonce size")
if len(nonces) != len(measurements):
raise ValueError(
"measurements and nonces lists have different lengths"
)

agg_shares = [vdaf.agg_init(agg_param)
for _ in range(vdaf.SHARES)]
for (nonce, measurement) in zip(nonces, measurements):
assert len(nonce) == vdaf.NONCE_SIZE

# Each Client shards its measurement into input shares.
# Sharding
rand = gen_rand(vdaf.RAND_SIZE)
(public_share, input_shares) = \
vdaf.shard(ctx, measurement, nonce, rand)

# Each Aggregator initializes its preparation state.
# Initialize preparation
prep_states = []
outbound_prep_shares = []
for j in range(vdaf.SHARES):
Expand All @@ -1523,7 +1517,7 @@ def run_vdaf(
prep_states.append(state)
outbound_prep_shares.append(share)

# Aggregators complete preparation.
# Complete preparation
for i in range(vdaf.ROUNDS - 1):
prep_msg = vdaf.prep_shares_to_prep(ctx,
agg_param,
Expand All @@ -1540,15 +1534,15 @@ def run_vdaf(
agg_param,
outbound_prep_shares)

# Each Aggregator computes and aggregates its output share.
# Aggregation
for j in range(vdaf.SHARES):
out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg)
assert not isinstance(out_share, tuple)
agg_shares[j] = vdaf.agg_update(agg_param,
agg_shares[j],
out_share)

# Collector unshards the aggregate.
# Unsharding
num_measurements = len(measurements)
agg_result = vdaf.unshard(agg_param, agg_shares,
num_measurements)
Expand Down Expand Up @@ -1591,36 +1585,8 @@ State transitions are made when the state is acted upon by the Aggregator's
local inputs and/or messages sent by its co-Aggregators. The initial state is
`Start`. The terminal states are: `Rejected`, indicating that the report cannot
be processed any further; and `Finished(out_share)`, indicating that the
Aggregator has recovered an output share `out_share`.

~~~ python
class State:
pass

class Start(State):
pass

class Continued(State, Generic[PrepState]):
def __init__(self, prep_state: PrepState, prep_round: int):
self.prep_state = prep_state
self.prep_round = prep_round

def __eq__(self, other: object) -> bool:
return isinstance(other, Continued) and \
self.prep_state == other.prep_state and \
self.prep_round == other.prep_round

class Finished(State, Generic[OutShare]):
def __init__(self, out_share: OutShare):
self.out_share = out_share

def __eq__(self, other: object) -> bool:
return isinstance(other, Finished) and \
self.out_share == other.out_share

class Rejected(State):
pass
~~~
Aggregator has recovered an output share `out_share`. For completeness, we
define these states in {{topo-states}}.

The methods described in this section are defined in terms of opaque byte
strings. A compatible `Vdaf` MUST specify methods for encoding public shares,
Expand Down Expand Up @@ -2013,7 +1979,7 @@ are used to encode measurements in some variants of Prio3
({{prio3-instantiations}}).

~~~ python
def encode_into_bit_vector(
def encode_into_bit_vec(
cls,
val: int,
bits: int) -> list[Self]:
Expand All @@ -2036,7 +2002,7 @@ def encode_into_bit_vector(
encoded.append(cls((val >> l) & 1))
return encoded

def decode_from_bit_vector(cls, vec: list[Self]) -> Self:
def decode_from_bit_vec(cls, vec: list[Self]) -> Self:
"""
Decode the field element from the bit representation, expressed
as a vector of field elements `vec`.
Expand Down Expand Up @@ -2395,7 +2361,7 @@ for various types of aggregation tasks are specified in
{{prio3-instantiations}}. Test vectors for each variant can be found in
{{test-vectors}}.

## Fully Linear Proof (FLP) Systems {#flp}
## Fully Linear Proofs (FLPs) {#flp}

Conceptually, an FLP is a two-party protocol executed by a prover and a
verifier. In actual use, however, the prover's computation is carried out by
Expand Down Expand Up @@ -2604,7 +2570,7 @@ accepts on each proof.
See {{security-multiproof}} for guidance on choosing the field size and number
of proofs.

## Construction {#prio3-construction}
## Specification {#prio3-construction}

This section specifies `Prio3`, an implementation of the `Vdaf` interface
defined in {{vdaf}}. The parameters and types required by the `Vdaf` interface
Expand Down Expand Up @@ -2949,22 +2915,6 @@ def prep_init(
prep_share = (verifiers_share, joint_rand_part)
return (prep_state, prep_share)

def prep_next(
self,
_ctx: bytes,
prep_state: Prio3PrepState[F],
prep_msg: Optional[bytes]
) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]:
joint_rand_seed = prep_msg
(out_share, corrected_joint_rand_seed) = prep_state

# If joint randomness was used, check that the value computed by
# the Aggregators matches the value indicated by the Client.
if joint_rand_seed != corrected_joint_rand_seed:
raise ValueError('joint randomness check failed')

return out_share

def prep_shares_to_prep(
self,
ctx: bytes,
Expand Down Expand Up @@ -2993,6 +2943,22 @@ def prep_shares_to_prep(
if self.flp.JOINT_RAND_LEN > 0:
joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts)
return joint_rand_seed

def prep_next(
self,
_ctx: bytes,
prep_state: Prio3PrepState[F],
prep_msg: Optional[bytes]
) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]:
joint_rand_seed = prep_msg
(out_share, corrected_joint_rand_seed) = prep_state

# If joint randomness was used, check that the value computed by
# the Aggregators matches the value indicated by the Client.
if joint_rand_seed != corrected_joint_rand_seed:
raise ValueError('joint randomness check failed')

return out_share
~~~

### Validity of Aggregation Parameters
Expand Down Expand Up @@ -3285,7 +3251,7 @@ struct {
} Prio3AggShare;
~~~

## FLP Construction {#flp-bbcggi19}
## FLP Specification {#flp-bbcggi19}

| Parameter | Value |
|:-----------------|:-------------------------------------------------|
Expand Down Expand Up @@ -3899,25 +3865,25 @@ class Sum(Valid[int, int, F]):
out.append(self.GADGETS[0].eval(self.field, [b]))

range_check = self.offset * shares_inv + \
self.field.decode_from_bit_vector(meas[:self.bits]) - \
self.field.decode_from_bit_vector(meas[self.bits:])
self.field.decode_from_bit_vec(meas[:self.bits]) - \
self.field.decode_from_bit_vec(meas[self.bits:])
out.append(range_check)
return out

def encode(self, measurement: int) -> list[F]:
encoded = []
encoded += self.field.encode_into_bit_vector(
encoded += self.field.encode_into_bit_vec(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vector(
encoded += self.field.encode_into_bit_vec(
measurement + self.offset.int(),
self.bits
)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return [self.field.decode_from_bit_vector(meas[:self.bits])]
return [self.field.decode_from_bit_vec(meas[:self.bits])]

def decode(self, output: list[F], _num_measurements: int) -> int:
return output[0].int()
Expand Down Expand Up @@ -4033,14 +3999,14 @@ class SumVec(Valid[list[int], list[int], F]):
def encode(self, measurement: list[int]) -> list[F]:
encoded = []
for val in measurement:
encoded += self.field.encode_into_bit_vector(
encoded += self.field.encode_into_bit_vec(
val, self.bits)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
truncated = []
for i in range(self.length):
truncated.append(self.field.decode_from_bit_vector(
truncated.append(self.field.decode_from_bit_vec(
meas[i * self.bits: (i + 1) * self.bits]
))
return truncated
Expand Down Expand Up @@ -4309,7 +4275,7 @@ class MultihotCountVec(Valid[list[int], list[int], F]):
count_vec = meas[:self.length]
weight = sum(count_vec, self.field(0))
weight_reported = \
self.field.decode_from_bit_vector(meas[self.length:])
self.field.decode_from_bit_vec(meas[self.length:])
weight_check = self.offset*shares_inv + weight - \
weight_reported

Expand All @@ -4327,7 +4293,7 @@ class MultihotCountVec(Valid[list[int], list[int], F]):

encoded = []
encoded += count_vec
encoded += self.field.encode_into_bit_vector(
encoded += self.field.encode_into_bit_vec(
(self.offset + weight_reported).int(),
self.bits_for_weight)
return encoded
Expand Down Expand Up @@ -4545,7 +4511,7 @@ fixed length. Further, the padding scheme must be non-ambiguous. For example,
each input could be padded with `b"\x01"` followed by as many `b"\x00"` bytes
as needed.

## Construction {#poplar1-construction}
## Specification {#poplar1-construction}

This section specifies `Poplar1`, an implementation of the `Vdaf` interface
({{vdaf}}). It is defined in terms of the `Idpf` implementation of
Expand Down Expand Up @@ -5224,7 +5190,7 @@ than is strictly needed. In particular, it may be sufficient to convey which
indices from the previous execution will have their children included in the
next. This would help reduce communication overhead.

## IDPF Construction {#idpf-bbcggi21}
## IDPF Specification {#idpf-bbcggi21}

In this section we specify a concrete IDPF suitable for instantiating Poplar1.
The constant and type definitions required by the `Idpf` interface are given in
Expand Down Expand Up @@ -6142,6 +6108,40 @@ class QueryGadget(Gadget[F]):
return wrapped_valid
~~~

# VDAF Preparation State {#topo-states}

This section lists the classes used to define each Aggregator's state during
VDAF preparation ({{vdaf-prep-comm}}).

~~~ python
class State:
pass

class Start(State):
pass

class Continued(State, Generic[PrepState]):
def __init__(self, prep_state: PrepState, prep_round: int):
self.prep_state = prep_state
self.prep_round = prep_round

def __eq__(self, other: object) -> bool:
return isinstance(other, Continued) and \
self.prep_state == other.prep_state and \
self.prep_round == other.prep_round

class Finished(State, Generic[OutShare]):
def __init__(self, out_share: OutShare):
self.out_share = out_share

def __eq__(self, other: object) -> bool:
return isinstance(other, Finished) and \
self.out_share == other.out_share

class Rejected(State):
pass
~~~

# Test Vectors {#test-vectors}

Test vectors for Prio3 ({{prio3}}) and Poplar1 ({{poplar1}}) are available at
Expand Down
4 changes: 2 additions & 2 deletions poc/tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def run_field_test(self, cls: type[Field]) -> None:
vals = [i for i in range(15)]
bits = 4
for val in vals:
encoded = cls.encode_into_bit_vector(val, bits)
self.assertTrue(cls.decode_from_bit_vector(
encoded = cls.encode_into_bit_vec(val, bits)
self.assertTrue(cls.decode_from_bit_vec(
encoded).int() == val)

def run_ntt_field_test(self, cls: type[NttField]) -> None:
Expand Down
Loading