Skip to content

Commit

Permalink
poc: Add application context string (*)
Browse files Browse the repository at this point in the history
WIP: We still need to propagate the changes through the draft and update
security considerations, in particular the line about weak parameters
being exploitable across tasks.

Add an application context string, denoted `ctx`, to sharding and
preparation APIs for DAFs and VDAFs. The intent is to use this for
defense-in-depth: by binding execution of a VDAF to its application
(concretely, in DAP we would set `ctx` to the task ID), we hope that an
attack on that application does not translate directly to an attack on
another application using the same parameters.

This change is motivated by the possibility of an offline attack against
the robustness of some Prio3 variants. The attack is as follows: fix an
invalid measurement, then search for measurement shares that derive
joint randomness that causes the circuit to accept. Once these shares
are found, they can be used to corrupt any batch that uses the same
Prio3 parameters.

To reduce the blast radius of this attack, we can bind `ctx` to the
joint randomness derivation, thereby ensuring the offline work done by
the attacker is only useful to break a single application.

This change is slightly more general: wherever we compute a domain
separation tag in either Prio3 or Poplar1, we append `ctx`. This is a
more invasive than necessary, but it's also more conservative.

Finally, we modify XofTurboShake128 and XofFixedKeyAes128 by changing
the length prefix for the domain separation tag from a single byte to
two bytes, thereby accommodating larger application context strings.
  • Loading branch information
cjpatton committed Sep 19, 2024
1 parent 4abca34 commit ce8f458
Show file tree
Hide file tree
Showing 27 changed files with 480 additions and 381 deletions.
12 changes: 12 additions & 0 deletions poc/gen_test_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,39 +90,46 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
if __name__ == '__main__':
from vdaf_poc import idpf_bbcggi21, vdaf_poplar1, vdaf_prio3, xof

ctx = b'some application'

# Prio3 variants
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Count(2),
None,
ctx,
[1],
0,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Count(3),
None,
ctx,
[1],
1,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Sum(2, 255),
None,
ctx,
[100],
0,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Sum(3, 255),
None,
ctx,
[100],
1,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3SumVec(2, 10, 8, 9),
None,
ctx,
[
list(range(10)),
[1] * 10,
Expand All @@ -134,6 +141,7 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
TEST_VECTOR_PATH,
vdaf_prio3.Prio3SumVec(3, 3, 16, 7),
None,
ctx,
[
[10000, 32000, 9],
[19342, 19615, 3061],
Expand All @@ -145,20 +153,23 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Histogram(2, 4, 2),
None,
ctx,
[2],
0,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Histogram(3, 11, 3),
None,
ctx,
[2],
1,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3MultihotCountVec(2, 4, 2, 2),
None,
ctx,
[[0, 1, 1, 0]],
0,
)
Expand All @@ -175,6 +186,7 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
TEST_VECTOR_PATH,
vdaf_poplar1.Poplar1(4),
(test_level, prefixes),
ctx,
[0b1101],
test_level,
)
Expand Down
3 changes: 3 additions & 0 deletions poc/tests/test_daf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TestDaf(
def shard(
self,
measurement: int,
_ctx: bytes,
nonce: bytes,
rand: bytes) -> tuple[None, list[Field128]]:
if len(nonce) != self.NONCE_SIZE:
Expand Down Expand Up @@ -58,6 +59,7 @@ def prep(
self,
_agg_id: int,
_agg_param: None,
_ctx: bytes,
nonce: bytes,
_public_share: None,
input_share: Field128) -> Field128:
Expand Down Expand Up @@ -113,6 +115,7 @@ def run_daf_test(
nonces = [gen_rand(daf.NONCE_SIZE) for _ in range(len(measurements))]
agg_result = run_daf(daf,
agg_param,
b'some application',
measurements,
nonces)
self.assertEqual(agg_result, expected_agg_result)
Expand Down
15 changes: 15 additions & 0 deletions poc/tests/test_vdaf_ping_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, num_rounds: int) -> None:
# `Vdaf`

def shard(self,
_ctx: bytes,
measurement: int,
_nonce: bytes,
_rand: bytes) -> tuple[str, list[int]]:
Expand All @@ -52,6 +53,7 @@ def is_valid(self,

def prep_init(self,
_verify_key: bytes,
_ctx: bytes,
_agg_id: int,
_agg_param: int,
_nonce: bytes,
Expand All @@ -66,6 +68,7 @@ def prep_init(self,
)

def prep_shares_to_prep(self,
_ctx: bytes,
_agg_param: int,
prep_shares: list[str]) -> str:
for prep_share in prep_shares[1:]:
Expand All @@ -74,6 +77,7 @@ def prep_shares_to_prep(self,
return prep_shares[0]

def prep_next(self,
_ctx: bytes,
prep_state: tuple[int, int],
prep_msg: str) -> Union[tuple[tuple[int, int], str], int]:
(current_round, out_share) = prep_state
Expand Down Expand Up @@ -159,9 +163,11 @@ def test_one_round(self) -> None:
verify_key = b''

measurement = 1337
ctx = b'some context'
nonce = b''
rand = b''
(public_share, input_shares) = vdaf.shard(
ctx,
measurement,
nonce,
rand,
Expand All @@ -170,6 +176,7 @@ def test_one_round(self) -> None:
agg_param = 23
(leader_state, msg) = vdaf.ping_pong_leader_init(
verify_key,
ctx,
vdaf.encode_agg_param(agg_param),
nonce,
vdaf.test_vec_encode_public_share(public_share),
Expand All @@ -179,6 +186,7 @@ def test_one_round(self) -> None:

(helper_state, msg) = vdaf.ping_pong_helper_init(
verify_key,
ctx,
vdaf.encode_agg_param(agg_param),
nonce,
vdaf.test_vec_encode_public_share(public_share),
Expand All @@ -188,6 +196,7 @@ def test_one_round(self) -> None:
self.assertEqual(helper_state, Finished(measurement))

(leader_state, msg) = vdaf.ping_pong_leader_continued(
ctx,
vdaf.encode_agg_param(agg_param),
leader_state,
cast(bytes, msg),
Expand All @@ -199,6 +208,7 @@ def test_multi_round(self) -> None:
"""Test the ping pong flow with multiple rounds."""
verify_key = b''
measurement = 1337
ctx = b'some application'
nonce = b''
rand = b''
agg_param = 23
Expand All @@ -209,13 +219,15 @@ def test_multi_round(self) -> None:
vdaf = PingPongTester(num_rounds)

(public_share, input_shares) = vdaf.shard(
ctx,
measurement,
nonce,
rand,
)

(leader_state, msg) = vdaf.ping_pong_leader_init(
verify_key,
ctx,
vdaf.encode_agg_param(agg_param),
nonce,
vdaf.test_vec_encode_public_share(public_share),
Expand All @@ -227,6 +239,7 @@ def test_multi_round(self) -> None:
if step == 0:
(helper_state, msg) = vdaf.ping_pong_helper_init(
verify_key,
ctx,
vdaf.encode_agg_param(agg_param),
nonce,
vdaf.test_vec_encode_public_share(public_share),
Expand All @@ -236,13 +249,15 @@ def test_multi_round(self) -> None:
else:
(helper_state, msg) = vdaf.ping_pong_helper_continued(
vdaf.encode_agg_param(agg_param),
ctx,
helper_state,
cast(bytes, msg),
)

if isinstance(leader_state, Continued):
(leader_state, msg) = vdaf.ping_pong_leader_continued(
vdaf.encode_agg_param(agg_param),
ctx,
leader_state,
cast(bytes, msg),
)
Expand Down
12 changes: 6 additions & 6 deletions poc/tests/test_xof.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_rejection_sampling(self) -> None:
# https://github.com/divergentdave/vdaf-rejection-sampling-search
expanded_vec = XofTurboShake128.expand_into_vec(
Field64,
bytes([0x86, 0x1e, 0x9e, 0x8e, 0x49, 0x44, 0xa2, 0x9f, 0xa4, 0x07,
0x43, 0x23, 0xaf, 0x39, 0xaa, 0xcf, 0xeb, 0xf5, 0xc8, 0x88,
0xa9, 0xe5, 0x5f, 0x0f, 0x9a, 0x9e, 0x4c, 0x70, 0x7d, 0xcd,
0x95, 0x55]),
bytes([0x44, 0x34, 0x1d, 0xc5, 0x2d, 0x71, 0xa2, 0xff, 0x2e, 0x4c,
0x30, 0x5e, 0x93, 0x35, 0xda, 0x9b, 0x19, 0xaf, 0xc6, 0x8e,
0x10, 0xb8, 0xb5, 0x43, 0x69, 0x0d, 0xad, 0x9d, 0x3b, 0xbb,
0x46, 0xba]),
b'', # domain separation tag
b'', # binder
18262,
13883,
)
assert expanded_vec[-1] == Field64(6866317754138817667)
assert expanded_vec[-1] == Field64(4857131209231097247)

def test_turboshake128(self) -> None:
test_xof(XofTurboShake128, Field128, 23)
Expand Down
7 changes: 5 additions & 2 deletions poc/vdaf_poc/daf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Daf(
def shard(
self,
measurement: Measurement,
ctx: bytes,
nonce: bytes,
rand: bytes) -> tuple[PublicShare, list[InputShare]]:
"""
Expand Down Expand Up @@ -84,6 +85,7 @@ def prep(
self,
agg_id: int,
agg_param: AggParam,
ctx: bytes,
nonce: bytes,
public_share: PublicShare,
input_share: InputShare) -> OutShare:
Expand Down Expand Up @@ -142,6 +144,7 @@ def run_daf(
AggResult,
],
agg_param: AggParam,
ctx: bytes,
measurements: list[Measurement],
nonces: list[bytes]) -> AggResult:
"""
Expand Down Expand Up @@ -169,12 +172,12 @@ def run_daf(
# distributes them among the Aggregators.
rand = gen_rand(daf.RAND_SIZE)
(public_share, input_shares) = \
daf.shard(measurement, nonce, rand)
daf.shard(measurement, ctx, nonce, rand)

# Each Aggregator prepares its input share for aggregation.
for j in range(daf.SHARES):
out_shares[j].append(
daf.prep(j, agg_param, nonce,
daf.prep(j, agg_param, ctx, nonce,
public_share, input_shares[j]))

# Each Aggregator aggregates its output shares into an aggregate
Expand Down
18 changes: 12 additions & 6 deletions poc/vdaf_poc/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def run_vdaf_test(
agg_result = run_vdaf(vdaf,
verify_key,
agg_param,
b"some application context",
nonces,
measurements)
self.assertEqual(agg_result, expected_agg_result)
Expand All @@ -85,6 +86,7 @@ class VdafTestVectorDict(Generic[Measurement, AggParam, AggResult], TypedDict):
shares: int
verify_key: str
agg_param: AggParam
ctx: str
prep: list[VdafPrepTestVectorDict[Measurement]]
agg_shares: list[str]
agg_result: Optional[AggResult]
Expand All @@ -105,6 +107,7 @@ def gen_test_vec_for_vdaf(
PrepMessage,
],
agg_param: AggParam,
ctx: bytes,
measurements: list[Measurement],
test_vec_instance: int,
print_test_vec: bool = True) -> AggResult:
Expand All @@ -120,6 +123,7 @@ def gen_test_vec_for_vdaf(
'shares': vdaf.SHARES,
'verify_key': verify_key.hex(),
'agg_param': agg_param,
'ctx': ctx.hex(),
'prep': [],
'agg_shares': [],
'agg_result': None, # set below
Expand All @@ -135,7 +139,7 @@ def gen_test_vec_for_vdaf(
# Each Client shards its measurement into input shares.
rand = test_vec_gen_rand(vdaf.RAND_SIZE)
(public_share, input_shares) = \
vdaf.shard(measurement, nonce, rand)
vdaf.shard(ctx, measurement, nonce, rand)

prep_test_vec: VdafPrepTestVectorDict[Measurement] = {
'measurement': measurement,
Expand All @@ -157,7 +161,7 @@ def gen_test_vec_for_vdaf(
prep_states = []
outbound_prep_shares = []
for j in range(vdaf.SHARES):
(state, share) = vdaf.prep_init(verify_key, j,
(state, share) = vdaf.prep_init(verify_key, ctx, j,
agg_param,
nonce,
public_share,
Expand All @@ -171,14 +175,15 @@ def gen_test_vec_for_vdaf(

# Aggregators recover their output shares.
for i in range(vdaf.ROUNDS - 1):
prep_msg = vdaf.prep_shares_to_prep(agg_param,
prep_msg = vdaf.prep_shares_to_prep(ctx,
agg_param,
outbound_prep_shares)
prep_test_vec['prep_messages'].append(
vdaf.test_vec_encode_prep_msg(prep_msg).hex())

outbound_prep_shares = []
for j in range(vdaf.SHARES):
out = vdaf.prep_next(prep_states[j], prep_msg)
out = vdaf.prep_next(ctx, prep_states[j], prep_msg)
assert isinstance(out, tuple)
(prep_states[j], prep_share) = out
outbound_prep_shares.append(prep_share)
Expand All @@ -190,14 +195,15 @@ def gen_test_vec_for_vdaf(

# The final outputs of the prepare phase are the output
# shares.
prep_msg = vdaf.prep_shares_to_prep(agg_param,
prep_msg = vdaf.prep_shares_to_prep(ctx,
agg_param,
outbound_prep_shares)
prep_test_vec['prep_messages'].append(
vdaf.test_vec_encode_prep_msg(prep_msg).hex())

outbound_out_shares = []
for j in range(vdaf.SHARES):
out_share = vdaf.prep_next(prep_states[j], prep_msg)
out_share = vdaf.prep_next(ctx, prep_states[j], prep_msg)
assert not isinstance(out_share, tuple)
outbound_out_shares.append(out_share)

Expand Down
Loading

0 comments on commit ce8f458

Please sign in to comment.