Skip to content

Commit

Permalink
Improve circuit for Prio3Sum (*)
Browse files Browse the repository at this point in the history
Replace `Sum` with `SumOfRangeCheckedInputs`. The new circuit has two
improvements:

1. No joint randomness is used, so it's safe to use a smaller field.

2. The largest measurement can be any positive integer, not just a power
   of two.

Update `Prio3Sum` accordingly by switching from `Field128` to `Field64`.

Co-authored-by: David Cook <[email protected]>
  • Loading branch information
cjpatton and divergentdave committed Aug 27, 2024
1 parent 63335a2 commit 584c799
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 177 deletions.
103 changes: 66 additions & 37 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -3663,37 +3663,61 @@ way. The parameters for this circuit are summarized below.

### Prio3Sum

| Parameter | Value |
|:------------------|:-------------------------------------------|
| `Valid` | `Sum(Field128, bits)` (this section) |
| `Field` | `Field128` ({{fields}}) |
| `PROOFS` | `1` |
| `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) |
| Parameter | Value |
|:------------------|:-----------------------------------------------|
| `Valid` | `Sum(Field64, max_measurement)` (this section) |
| `Field` | `Field64` ({{fields}}) |
| `PROOFS` | `1` |
| `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) |
{: title="Parameters for Prio3Sum."}

The next instance of Prio3 supports summing of integers in a pre-determined
range. Each measurement is an integer in range `[0, 2^bits)`, where `bits` is an
associated parameter.
range. Each measurement is an integer in `range(max_measurement+1)`, where
`max_measurement` is an associated parameter equal to the largest valid
measurement.

The range check is accomplished by encoding the measurement as a bit vector,
encoding the measurement plus an offset as a bit vector, then checking that the
two encoded integers are consistent. Let

* `bits = max_measurement.bit_length()`, the number of bits needed to encode
the largest valid measurement

* `offset = 2**bits - 1 - max_measurement`

The first bit-encoded integer is the measurement itself. Note that only
measurements between `0` and `2**bits - 1` can be encoded this way with
`bits` bits. The second bit-encoded integer is the sum of the measurement and
`offset`. Observe that this sum can only be encoded this way if it is between `0`
and `2**bits - 1`, which implies that the measurement is between `-offset` and
`max_measurement`.

The validity circuit is denoted `Sum`. The measurement is encoded as a
length-`bits` vector of field elements, where the `l`th element of the vector
represents the `l`th bit of the summand:
The circuit, denoted `Sum`, first checks that each entry of both bit vectors is
a one or a zero. It then decodes both the measurement and the offset
measurement, and subtracts `offset` from the latter. It then checks if these
two values are equal. Since both the measurement and the measurement plus
`offset` are in the same range of `0` to `2**bits - 1`, this means that the
measurement itself is between `0` and `max_measurement`.

The measurement is encoded in `2*bits` field elements as follows:

~~~ python
def encode(self, measurement: int) -> list[F]:
if 0 > measurement or measurement >= 2 ** self.MEAS_LEN:
raise ValueError('measurement out of range')

return self.field.encode_into_bit_vector(measurement,
self.MEAS_LEN)
encoded = []
encoded += self.field.encode_into_bit_vector(
measurement,
self.bits
)
encoded += self.field.encode_into_bit_vector(
measurement + self.offset.as_unsigned(),
self.bits
)
return encoded

def truncate(self, meas: list[F]) -> list[F]:
return [self.field.decode_from_bit_vector(meas)]
return [self.field.decode_from_bit_vector(meas[:self.bits])]

def decode(
self,
output: list[F],
_num_measurements: int) -> int:
def decode(self, output: list[F], _num_measurements: int) -> int:
return output[0].as_unsigned()
~~~

Expand All @@ -3713,26 +3737,31 @@ def eval(
self,
meas: list[F],
joint_rand: list[F],
_num_shares: int) -> list[F]:
num_shares: int) -> list[F]:
self.check_valid_eval(meas, joint_rand)
out = self.field(0)
r = joint_rand[0]
shares_inv = self.field(num_shares).inv()

out = []
for b in meas:
out += r * self.GADGETS[0].eval(self.field, [b])
r *= joint_rand[0]
return [out]
out.append(self.GADGETS[0].eval(self.field, [b]))

range_check = self.offset * shares_inv + \
self.field.decode_from_bit_vector(meas[:self.bits]) - \
self.field.decode_from_bit_vector(meas[self.bits:])
out.append(range_check)
return out
~~~

| Parameter | Value |
|:------------------|:--------------------------|
| `GADGETS` | `[Range2]` |
| `GADGET_CALLS` | `[bits]` |
| `MEAS_LEN` | `bits` |
| `OUTPUT_LEN` | `1` |
| `JOINT_RAND_LEN` | `1` |
| `EVAL_OUTPUT_LEN` | `1` |
| `Measurement` | `int` in `range(2**bits)` |
| `AggResult` | `int` |
| Parameter | Value |
|:------------------|:----------------------------------|
| `GADGETS` | `[Range2]` |
| `GADGET_CALLS` | `[2*bits]` |
| `MEAS_LEN` | `2*bits` |
| `OUTPUT_LEN` | `1` |
| `JOINT_RAND_LEN` | `0` |
| `EVAL_OUTPUT_LEN` | `2*bits + 1` |
| `Measurement` | `int` in `range(max_measurement)` |
| `AggResult` | `int` |
{: title="Parameters of validity circuit Sum."}

### Prio3SumVec
Expand Down
4 changes: 2 additions & 2 deletions poc/gen_test_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ def gen_test_vec_for_xof(cls: type[Xof]) -> None:
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Sum(2, 8),
vdaf_prio3.Prio3Sum(2, 255),
None,
[100],
0,
)
gen_test_vec_for_vdaf(
TEST_VECTOR_PATH,
vdaf_prio3.Prio3Sum(3, 8),
vdaf_prio3.Prio3Sum(3, 255),
None,
[100],
1,
Expand Down
21 changes: 5 additions & 16 deletions poc/tests/test_flp_bbcggi19.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from vdaf_poc.field import Field64, Field96, Field128, NttField
from vdaf_poc.flp_bbcggi19 import (Count, FlpBBCGGI19, Histogram, Mul,
MultihotCountVec, PolyEval, Range2, Sum,
SumOfRangeCheckedInputs, SumVec, Valid)
SumVec, Valid)
from vdaf_poc.test_utils import TestFlpBBCGGI19

Measurement = TypeVar("Measurement")
Expand Down Expand Up @@ -49,7 +49,7 @@ class TestAverage(Sum):
down.
"""

def decode(self, output: list[Field128], num_measurements: int) -> int:
def decode(self, output: list[Field64], num_measurements: int) -> int:
total = super().decode(output, num_measurements)
return total // num_measurements

Expand All @@ -66,23 +66,12 @@ def test(self) -> None:

class TestSum(TestFlpBBCGGI19):
def test(self) -> None:
flp = FlpBBCGGI19(Sum(Field128, 10))
self.run_flp_test(flp, [
(flp.encode(0), True),
(flp.encode(100), True),
(flp.encode(2 ** 10 - 1), True),
(flp.field.rand_vec(10), False),
])
self.run_encode_truncate_decode_test(flp, [0, 100, 2 ** 10 - 1])


class TestSumOfRangeCheckedInputs(TestFlpBBCGGI19):
def test(self) -> None:
flp = FlpBBCGGI19(SumOfRangeCheckedInputs(Field128, 10_000))
flp = FlpBBCGGI19(Sum(Field64, 10_000))
self.run_flp_test(flp, [
(flp.encode(0), True),
(flp.encode(1337), True),
(flp.encode(9_999), True),
(flp.encode(9999), True),
(flp.encode(10000), True),
(flp.field.zeros(flp.MEAS_LEN), False),
])

Expand Down
6 changes: 3 additions & 3 deletions poc/tests/test_vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def test_3_shares(self) -> None:

class TestPrio3Sum(TestVdaf):
def test(self) -> None:
prio3 = Prio3Sum(2, 8)
prio3 = Prio3Sum(2, 147)
self.assertEqual(prio3.ID, 0x00000001)
self.run_vdaf_test(prio3, None, [0, 147, 1, 0, 11, 0], 159)
self.run_vdaf_test(prio3, None, [100], 100)

def test_3_shares(self) -> None:
prio3 = Prio3Sum(3, 8)
prio3 = Prio3Sum(3, 100)
self.run_vdaf_test(prio3, None, [100], 100)


Expand Down Expand Up @@ -156,7 +156,7 @@ def test_3_shares(self) -> None:

class TestPrio3Average(TestVdaf):
def test(self) -> None:
prio3 = Prio3Average(2, 3)
prio3 = Prio3Average(2, 5)
self.run_vdaf_test(prio3, None, [1, 5, 1, 1, 4, 1, 3, 2], 2)

def test_is_valid(self) -> None:
Expand Down
106 changes: 19 additions & 87 deletions poc/vdaf_poc/flp_bbcggi19.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,76 +629,6 @@ def decode(self, output: list[F], _num_measurements: int) -> int:
return output[0].as_unsigned()


class Sum(
Valid[
int, # Measurement, `range(2 ** self.bits)`
int, # AggResult
F,
]):
GADGETS: list[Gadget[F]] = [Range2()]
JOINT_RAND_LEN = 1
OUTPUT_LEN = 1
EVAL_OUTPUT_LEN = 1

# Class object for the field.
field: type[F]

def __init__(self, field: type[F], bits: int):
"""
Instantiate an instace of the `Sum` circuit for measurements in range `[0,
2^bits)`.
"""

if 2 ** bits >= field.MODULUS:
raise ValueError('bit size exceeds field modulus')

self.field = field
self.GADGET_CALLS = [bits]
self.MEAS_LEN = bits

# NOTE: This method is excerpted in the document, de-indented. Its
# width should be limited to 69 columns after de-indenting, or 73
# columns before de-indenting, to avoid warnings from xml2rfc.
# ===================================================================
def eval(
self,
meas: list[F],
joint_rand: list[F],
_num_shares: int) -> list[F]:
self.check_valid_eval(meas, joint_rand)
out = self.field(0)
r = joint_rand[0]
for b in meas:
out += r * self.GADGETS[0].eval(self.field, [b])
r *= joint_rand[0]
return [out]

# NOTE: The encode(), truncate(), and decode() methods are excerpted
# in the document, de-indented. Their width should be limited to 69
# columns after de-indenting, or 73 columns before de-indenting, to
# avoid warnings from xml2rfc.
# ===================================================================
def encode(self, measurement: int) -> list[F]:
if 0 > measurement or measurement >= 2 ** self.MEAS_LEN:
raise ValueError('measurement out of range')

return self.field.encode_into_bit_vector(measurement,
self.MEAS_LEN)

def truncate(self, meas: list[F]) -> list[F]:
return [self.field.decode_from_bit_vector(meas)]

def decode(
self,
output: list[F],
_num_measurements: int) -> int:
return output[0].as_unsigned()

def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]:
test_vec['bits'] = int(self.MEAS_LEN)
return ['bits']


class Histogram(
Valid[
int, # Measurement, `range(length)`
Expand Down Expand Up @@ -1092,8 +1022,7 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]:
return ['length', 'bits', 'chunk_length']


# TODO(issue #306) Replace `Sum` with this type.
class SumOfRangeCheckedInputs(
class Sum(
Valid[
int, # Measurement, `range(self.max_measurement + 1)`
int, # AggResult
Expand All @@ -1106,33 +1035,36 @@ class SumOfRangeCheckedInputs(

def __init__(self, field: type[F], max_measurement: int):
"""
Similar to `Sum` but with an arbitrary bound.
The circuit checks that the measurement is in
A circuit that checks that the measurement is in
`range(max_measurement+1)`. This is accomplished by encoding the
measurement in a way that ensures it is in range, then comparing the
reported measurement to the range checked measurement.
measurement as a bit vector, encoding the measurement plus an
offset as a bit vector, then checking that the two encoded
integers are consistent.
Let
- `bits = max_measurement.bit_length()`
- `offset = 2**bits - 1 - max_measurement`
The range checked measurement is the bit-encoding of `offset` plus the
measurement. Observe that only measurements in at most
`max_measurement` can be encoded with `bits` bits.
To do the range check, the circuit first checks that each
entry of this bit vector is a one or a zero. It then decodes
it and subtracts it from `offset` plus the reported value.
Since the range checked measurement is in the correct range,
equality implies that the reported measurement is as well.
The first bit-encoded integer is the measurement itself. Note
that only measurements between `0` and `2**bits - 1` can be
encoded this way with `bits` bits. The second bit-encoded integer
is the sum of the measurement and `offset`. Observe that only
measurements between `-offset` and `max_measurement` inclusive
can be encoded this way with `bits` bits.
To do the range check, the circuit first checks that each entry
of both bit vectors is a one or a zero. It then decodes both the
measurement and the offset measurement, and subtracts `offset`
from the latter. It then checks if these two values are equal.
Since both the measurement and the measurement plus `offset` are
in the same range of `0` to `2**bits - 1`, this means that the
measurement itself is between `0` and `max_measurement`.
"""
self.field = field
self.bits = max_measurement.bit_length()
self.offset = self.field(2 ** self.bits - 1 - max_measurement)
self.max_measurement = max_measurement

if 2 ** self.bits >= self.field.MODULUS:
raise ValueError('bound exceeds field modulus')

Expand Down
8 changes: 4 additions & 4 deletions poc/vdaf_poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,17 +560,17 @@ def __init__(self, shares: int):
super().__init__(shares, flp, 1)


class Prio3Sum(Prio3[int, int, Field128]):
class Prio3Sum(Prio3[int, int, Field64]):
ID = 0x00000001
xof = XofTurboShake128
VERIFY_KEY_SIZE = xof.SEED_SIZE

# Name of the VDAF, for use in test vector filenames.
test_vec_name = 'Prio3Sum'

def __init__(self, shares: int, bits: int):
flp = flp_bbcggi19.FlpBBCGGI19[int, int, Field128](
flp_bbcggi19.Sum(Field128, bits)
def __init__(self, shares: int, max_measurement: int):
flp = flp_bbcggi19.FlpBBCGGI19[int, int, Field64](
flp_bbcggi19.Sum(Field64, max_measurement)
)
super().__init__(shares, flp, 1)

Expand Down
Loading

0 comments on commit 584c799

Please sign in to comment.