Skip to content

Commit

Permalink
Change measurement type of Prio3MultihotCountVec
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave committed Oct 17, 2024
1 parent dcadeaa commit ad3b6dd
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 73 deletions.
26 changes: 13 additions & 13 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -4156,18 +4156,18 @@ class Histogram(Valid[int, list[int], F]):
| `PROOFS` | `1` |
{: title="Parameters for Prio3MultihotCountVec."}

For this instance of Prio3, each measurement is a vector of ones and zeros,
where the number of ones is bounded. This provides a functionality similar to
Prio3Histogram except that more than one entry (or none at all) may be
non-zero. This allows Prio3MultihotCountVec to be composed with a randomized
response mechanism, like {{EPK14}}, for providing differential privacy. (For
example, each Client would set each entry to one with some small probability.)
For this instance of Prio3, each measurement is a vector of true and false
values, where the number of true values is bounded. This provides a
functionality similar to Prio3Histogram except that more than one entry (or none
at all) may be non-zero. This allows Prio3MultihotCountVec to be composed with a
randomized response mechanism, like {{EPK14}}, for providing differential
privacy. (For example, each Client would set each entry with some small
probability.)

The validity circuit is denoted `MultihotCountVec` and has three parameters:
`length`, the number of entries in the count vector; `max_weight`, the
maximum number of non-zero entries (i.e., the weight must be at most
`max_weight`); and `chunk_length`, used the same way as in {{prio3sumvec}} and
{{prio3histogram}}.
`length`, the number of entries in the count vector; `max_weight`, the maximum
number of true entries (i.e., the weight must be at most `max_weight`); and
`chunk_length`, used the same way as in {{prio3sumvec}} and {{prio3histogram}}.

Validation works as follows. Let

Expand All @@ -4185,7 +4185,7 @@ zero if and only if the reported weight is equal to the true weight. The
complete circuit is defined below.

~~~ python
class MultihotCountVec(Valid[list[int], list[int], F]):
class MultihotCountVec(Valid[list[bool], list[int], F]):
EVAL_OUTPUT_LEN = 2
field: type[F]

Expand Down Expand Up @@ -4232,12 +4232,12 @@ class MultihotCountVec(Valid[list[int], list[int], F]):
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
def encode(self, measurement: list[bool]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))
count_vec = [self.field(int(x)) for x in measurement]

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))
Expand Down
14 changes: 7 additions & 7 deletions poc/gen_test_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,15 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
vdaf_prio3.Prio3MultihotCountVec(2, 4, 2, 2),
None,
ctx,
[[0, 1, 1, 0]],
[[False, True, True, False]],
0,
)
gen_test_vec_for_vdaf(
vdaf_test_vec_path,
vdaf_prio3.Prio3MultihotCountVec(4, 10, 2, 3),
None,
ctx,
[[0, 1, 0, 0, 0, 0, 0, 0, 0, 1]],
[[False, True, False, False, False, False, False, False, False, True]],
1,
)
gen_test_vec_for_vdaf(
Expand All @@ -229,11 +229,11 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
None,
ctx,
[
[0, 1, 1, 0],
[0, 0, 1, 0],
[0, 0, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
[False, True, True, False],
[False, False, True, False],
[False, False, False, False],
[True, True, True, False],
[True, True, True, True],
],
2,
)
Expand Down
25 changes: 18 additions & 7 deletions poc/tests/test_flp_bbcggi19.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def test(self) -> None:

# Successful cases:
cases = [
(flp.encode([0, 0, 0, 0]), True),
(flp.encode([0, 1, 0, 0]), True),
(flp.encode([0, 1, 1, 0]), True),
(flp.encode([1, 1, 0, 0]), True),
(flp.encode([False, False, False, False]), True),
(flp.encode([False, True, False, False]), True),
(flp.encode([False, True, True, False]), True),
(flp.encode([True, True, False, False]), True),
]
# Failure cases: too many number of 1s, should fail weight check.
cases += [
Expand All @@ -114,15 +114,26 @@ def test(self) -> None:
for i in range(valid.max_weight + 1, valid.length + 1)
]
# Failure case: pass count check but fail bit check.
cases += [(flp.encode([flp.field.MODULUS - 1, 1, 0, 0]), False)]
cases += [
(
[
flp.field(flp.field.MODULUS - 1),
flp.field(1),
flp.field(0),
flp.field(0),
]
+ [flp.field(0)] * valid.bits_for_weight,
False
)
]
self.run_flp_test(flp, cases)

def test_small(self) -> None:
flp = FlpBBCGGI19(MultihotCountVec(Field128, 1, 1, 1))

self.run_flp_test(flp, [
(flp.encode([0]), True),
(flp.encode([1]), True),
(flp.encode([False]), True),
(flp.encode([True]), True),
([flp.field(0), flp.field(1337)], False),
([flp.field(1), flp.field(0)], False),
])
Expand Down
30 changes: 24 additions & 6 deletions poc/tests/test_vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,30 @@ def test(self) -> None:
# chunk_length = 2.
prio3 = Prio3MultihotCountVec(2, 4, 2, 2)
self.assertEqual(prio3.ID, 0x00000005)
self.run_vdaf_test(prio3, None, [[0, 0, 0, 0]], [0, 0, 0, 0])
self.run_vdaf_test(prio3, None, [[0, 1, 0, 0]], [0, 1, 0, 0])
self.run_vdaf_test(prio3, None, [[0, 1, 1, 0]], [0, 1, 1, 0])
self.run_vdaf_test(
prio3, None, [[0, 1, 1, 0], [0, 1, 0, 1]], [0, 2, 1, 1])
self.run_vdaf_test(prio3, None, [[0, 1, 1, 0]], [0, 1, 1, 0])
prio3,
None,
[[False, False, False, False]],
[0, 0, 0, 0],
)
self.run_vdaf_test(
prio3,
None,
[[False, True, False, False]],
[0, 1, 0, 0],
)
self.run_vdaf_test(
prio3,
None,
[[False, True, True, False]],
[0, 1, 1, 0],
)
self.run_vdaf_test(
prio3,
None,
[[False, True, True, False], [False, True, False, True]],
[0, 2, 1, 1],
)

def test_3_shares(self) -> None:
# Prio3MultihotCountVec with length = 11, max_weight = 5,
Expand All @@ -149,7 +167,7 @@ def test_3_shares(self) -> None:
self.run_vdaf_test(
prio3,
None,
[[1] * 5 + [0] * 6],
[[True] * 5 + [False] * 6],
[1] * 5 + [0] * 6,
)

Expand Down
10 changes: 5 additions & 5 deletions poc/vdaf_poc/flp_bbcggi19.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,12 +738,12 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]:
# NOTE: This class is excerpted in the document. Its width should be
# limited to 69 columns to avoid warnings from xml2rfc.
# ===================================================================
class MultihotCountVec(Valid[list[int], list[int], F]):
class MultihotCountVec(Valid[list[bool], list[int], F]):
# REMOVE ME
"""
A validity circuit that checks each Client's measurement is a bit
vector with at most `max_weight` number of 1s. We call the number
of 1s in the vector the vector's "weight".
vector with at most `max_weight` number of true values. We call
the number of true values in the vector the vector's "weight".
The circuit determines whether the weight of the vector is at most
`max_weight` as follows. First, it computes the weight of the
Expand Down Expand Up @@ -818,12 +818,12 @@ def __init__(self,
self.OUTPUT_LEN = self.length
self.JOINT_RAND_LEN = self.GADGET_CALLS[0]

def encode(self, measurement: list[int]) -> list[F]:
def encode(self, measurement: list[bool]) -> list[F]:
if len(measurement) != self.length:
raise ValueError('invalid Client measurement length')

# The first part is the vector of counters.
count_vec = list(map(self.field, measurement))
count_vec = [self.field(int(x)) for x in measurement]

# The second part is the reported weight.
weight_reported = sum(count_vec, self.field(0))
Expand Down
2 changes: 1 addition & 1 deletion poc/vdaf_poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def __init__(
)


class Prio3MultihotCountVec(Prio3[list[int], list[int], Field128]):
class Prio3MultihotCountVec(Prio3[list[bool], list[int], Field128]):
ID = 0x00000005
xof = XofTurboShake128
VERIFY_KEY_SIZE = xof.SEED_SIZE
Expand Down
8 changes: 4 additions & 4 deletions test_vec/vdaf/Prio3MultihotCountVec_0.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
0,
1,
1,
0
false,
true,
true,
false
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down
20 changes: 10 additions & 10 deletions test_vec/vdaf/Prio3MultihotCountVec_1.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@
"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebf"
],
"measurement": [
0,
1,
0,
0,
0,
0,
0,
0,
0,
1
false,
true,
false,
false,
false,
false,
false,
false,
false,
true
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down
40 changes: 20 additions & 20 deletions test_vec/vdaf/Prio3MultihotCountVec_2.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
0,
1,
1,
0
false,
true,
true,
false
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down Expand Up @@ -59,10 +59,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
0,
0,
1,
0
false,
false,
true,
false
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down Expand Up @@ -97,10 +97,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
0,
0,
0,
0
false,
false,
false,
false
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down Expand Up @@ -135,10 +135,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
1,
1,
1,
0
true,
true,
true,
false
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down Expand Up @@ -173,10 +173,10 @@
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"
],
"measurement": [
1,
1,
1,
1
true,
true,
true,
true
],
"nonce": "000102030405060708090a0b0c0d0e0f",
"out_shares": [
Expand Down

0 comments on commit ad3b6dd

Please sign in to comment.