From b6aa65086be8920493d33b8dbfc0dce04c56b4f0 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Mon, 7 Oct 2024 11:27:37 -0700 Subject: [PATCH] Prio3: Improve circuit specifications Rework the spec of each Prio3 variant by listing the entire class of each circuit. Also, list the specification of each gadget in the appendix. * Prio3Sum: Use `PolyEval` instead of `Range2`, as this is a much more useful gadget to list in the doc. * Say "variant" instead of "instance". * Document missing operations on polynomials. * Modify the constructor for `Field` elements by interpreting negative integers as additive inverses. This is necessary in order to allow expressing polynomial coefficients in the `PolyEval` circuit independent of the field modulus. --- draft-irtf-cfrg-vdaf.md | 927 +++++++++++++++++++-------------- poc/tests/test_flp_bbcggi19.py | 12 +- poc/vdaf_poc/field.py | 5 +- poc/vdaf_poc/flp_bbcggi19.py | 265 ++++------ 4 files changed, 653 insertions(+), 556 deletions(-) diff --git a/draft-irtf-cfrg-vdaf.md b/draft-irtf-cfrg-vdaf.md index b12ac8cb..80838e06 100644 --- a/draft-irtf-cfrg-vdaf.md +++ b/draft-irtf-cfrg-vdaf.md @@ -830,11 +830,22 @@ Some common functionalities: overriding the start of the range and overriding the step between successive output values. +* `poly_eval(field: type[F], p: list[F], x: F) -> F` returns the result of + evaluating the polynomial, `p(x)`. The coefficients of polynomials are stored + in lists in ascending order of degree, starting with the constant coefficient. + * `poly_interp(field: type[F], inputs: list[F], outputs: list[F]) -> list[F]` returns the coefficients of the lowest degree polynomial `p` for which `p(input[k]) == output[k]` for all `k`. Normally this will be computed using the Number Theoretic Transform (NTT) {{SML24}}. +* `poly_mul(field: type[F], p: list[F], q: list[F]) -> list[F]` returns + the product of two polynomials. + +* `poly_strip(field: type[F], p: list[F]) -> list[F]` strips the zeros from the + end of the input polynomial. That is, it returns `p[:i]` where `i` is the + index of the last non-zero coefficient of `p`. + # Overview ~~~ @@ -3799,12 +3810,13 @@ def decide(self, verifier: list[F]) -> bool: return True ~~~ -## Instantiations {#prio3-instantiations} +## Variants {#prio3-instantiations} -This section specifies instantiations of Prio3 for various measurement types. -Each is determined by a field ({{field}}), a validity circuit -({{flp-bbcggi19-valid}}), an XOF ({{xof}}). and the number of proofs to +This section specifies instantiations of Prio3 for various aggregation tasks. +Each variant is determined by a field ({{field}}), a validity circuit +({{flp-bbcggi19-valid}}), an XOF ({{xof}}), and the number of proofs to generate and verify. Test vectors for each can be found in {{test-vectors}}. +All gadgets are listed in {{gadgets}}. ### Prio3Count @@ -3816,47 +3828,45 @@ generate and verify. Test vectors for each can be found in {{test-vectors}}. | `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) | {: title="Parameters for Prio3Count."} -Our first instance of Prio3 is for a simple counter: each measurement is either -one or zero and the aggregate result is the sum of the measurements. - -Its validity circuit, denoted `Count`, uses the following degree-2, arity-2 -gadget, denoted `Mul`: +Our first variant of Prio3 is for a simple counter: each measurement is either +one or zero and the aggregate result is the sum of the measurements. Its +validity circuit uses the multiplication gadget `Mul` specified in +{{gadget-mul}}, which takes two inputs and multiplies them. The circuit is +specified below. ~~~ python -def eval(self, _field: type[F], inp: list[F]) -> F: - self.check_gadget_eval(inp) - return inp[0] * inp[1] -~~~ +class Count(Valid[int, int, F]): + GADGETS: list[Gadget[F]] = [Mul()] + GADGET_CALLS = [1] + MEAS_LEN = 1 + JOINT_RAND_LEN = 0 + OUTPUT_LEN = 1 + EVAL_OUTPUT_LEN = 1 -The call to `check_gadget_eval()` raises an error if the length of the input is -not equal to the gadget's `ARITY` parameter. + # Class object for the field. + field: type[F] -The `Count` validity circuit is defined as + def __init__(self, field: type[F]): + self.field = field -~~~ python -def eval( - self, - meas: list[F], - joint_rand: list[F], - _num_shares: int) -> list[F]: - squared = self.GADGETS[0].eval(self.field, [meas[0], meas[0]]) - return [squared - meas[0]] -~~~ + def eval( + self, + meas: list[F], + joint_rand: list[F], + _num_shares: int) -> list[F]: + squared = self.GADGETS[0].eval(self.field, + [meas[0], meas[0]]) + return [squared - meas[0]] + + def encode(self, measurement: int) -> list[F]: + return [self.field(measurement)] -The measurement is encoded and decoded as a singleton vector in the natural -way. The parameters for this circuit are summarized below. - -| Parameter | Value | -|:------------------|:-----------------------| -| `GADGETS` | `[Mul]` | -| `GADGET_CALLS` | `[1]` | -| `MEAS_LEN` | `1` | -| `OUTPUT_LEN` | `1` | -| `JOINT_RAND_LEN` | `0` | -| `EVAL_OUTPUT_LEN` | `1` | -| `Measurement` | `int` in `[0, 2)` | -| `AggResult` | `int` | -{: title="Parameters of validity circuit Count."} + def truncate(self, meas: list[F]) -> list[F]: + return meas + + def decode(self, output: list[F], _num_measurements: int) -> int: + return output[0].as_unsigned() +~~~ ### Prio3Sum @@ -3868,10 +3878,9 @@ way. The parameters for this circuit are summarized below. | `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) | {: title="Parameters for Prio3Sum."} -The next instance of Prio3 supports summing of integers in a pre-determined +The next variant of Prio3 supports summing of integers in a pre-determined range. Each measurement is an integer in the range `[0, max_measurement]`, -where `max_measurement` is an associated parameter equal to the largest valid -measurement. +where `max_measurement` defines the largest valid measurement. The range check is accomplished by encoding the measurement as a bit vector, encoding the measurement plus an offset as a bit vector, then checking that the @@ -3889,77 +3898,69 @@ bits. The second bit-encoded integer is the sum of the measurement and `0` and `2^bits - 1`, which implies that the measurement is between `-offset` and `max_measurement`. -The circuit, denoted `Sum`, first checks that each entry of both bit vectors is -a one or a zero. It then decodes both the measurement and the offset -measurement, and subtracts `offset` from the latter. It then checks if these -two values are equal. Since both the measurement and the measurement plus -`offset` are in the same range of `[0, 2^bits)`, this means that the -measurement itself is between `0` and `max_measurement`. - -The measurement is encoded in `2*bits` field elements as follows: - -~~~ python -def encode(self, measurement: int) -> list[F]: - encoded = [] - encoded += self.field.encode_into_bit_vector( - measurement, - self.bits - ) - encoded += self.field.encode_into_bit_vector( - measurement + self.offset.as_unsigned(), - self.bits - ) - return encoded - -def truncate(self, meas: list[F]) -> list[F]: - return [self.field.decode_from_bit_vector(meas[:self.bits])] - -def decode(self, output: list[F], _num_measurements: int) -> int: - return output[0].as_unsigned() -~~~ +The circuit first checks that each entry of both bit vectors is a one or a +zero. It then decodes both the measurement and the offset measurement, and +subtracts `offset` from the latter. It then checks if these two values are +equal. Since both the measurement and the measurement plus `offset` are in the +same range of `[0, 2^bits)`, this means that the measurement itself is +between `0` and `max_measurement`. -The validity circuit checks that the input consists of ones and zeros. Its -gadget, denoted `Range2`, is the degree-2, arity-1 gadget defined as +The circuit uses the polynomial-evaluation gadget `PolyEval` specified in +{{gadget-poly-eval}}. The polynomial is `p(x) = x^2 - x`, which is equal to `0` +if and only if `x` is in `[0, 2)`. The complete circuit is specified below. -~~~ python -def eval(self, _field: type[F], inp: list[F]) -> F: - self.check_gadget_eval(inp) - return inp[0] * inp[0] - inp[0] ~~~ +class Sum(Valid[int, int, F]): + GADGETS: list[Gadget[F]] = [PolyEval([0, -1, 1])] + JOINT_RAND_LEN = 0 + OUTPUT_LEN = 1 + field: type[F] + + def __init__(self, field: type[F], max_measurement: int): + self.field = field + self.bits = max_measurement.bit_length() + self.offset = self.field(2**self.bits - 1 - max_measurement) + self.max_measurement = max_measurement + self.GADGET_CALLS = [2 * self.bits] + self.MEAS_LEN = 2 * self.bits + self.EVAL_OUTPUT_LEN = 2 * self.bits + 1 + + def eval( + self, + meas: list[F], + joint_rand: list[F], + num_shares: int) -> list[F]: + shares_inv = self.field(num_shares).inv() + + out = [] + for b in meas: + out.append(self.GADGETS[0].eval(self.field, [b])) + + range_check = self.offset * shares_inv + \ + self.field.decode_from_bit_vector(meas[:self.bits]) - \ + self.field.decode_from_bit_vector(meas[self.bits:]) + out.append(range_check) + return out + + def encode(self, measurement: int) -> list[F]: + encoded = [] + encoded += self.field.encode_into_bit_vector( + measurement, + self.bits + ) + encoded += self.field.encode_into_bit_vector( + measurement + self.offset.as_unsigned(), + self.bits + ) + return encoded -The `Sum` validity circuit is defined as + def truncate(self, meas: list[F]) -> list[F]: + return [self.field.decode_from_bit_vector(meas[:self.bits])] -~~~ python -def eval( - self, - meas: list[F], - joint_rand: list[F], - num_shares: int) -> list[F]: - shares_inv = self.field(num_shares).inv() - - out = [] - for b in meas: - out.append(self.GADGETS[0].eval(self.field, [b])) - - range_check = self.offset * shares_inv + \ - self.field.decode_from_bit_vector(meas[:self.bits]) - \ - self.field.decode_from_bit_vector(meas[self.bits:]) - out.append(range_check) - return out + def decode(self, output: list[F], _num_measurements: int) -> int: + return output[0].as_unsigned() ~~~ -| Parameter | Value | -|:------------------|:------------------------------------| -| `GADGETS` | `[Range2]` | -| `GADGET_CALLS` | `[2*bits]` | -| `MEAS_LEN` | `2*bits` | -| `OUTPUT_LEN` | `1` | -| `JOINT_RAND_LEN` | `0` | -| `EVAL_OUTPUT_LEN` | `2*bits + 1` | -| `Measurement` | `int` in `[0, max_measurement]` | -| `AggResult` | `int` | -{: title="Parameters of validity circuit Sum."} - ### Prio3SumVec | Parameter | Value | @@ -3970,148 +3971,140 @@ def eval( | `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) | {: title="Parameters for Prio3SumVec."} -This instance of Prio3 supports summing a vector of integers. It has three -parameters, `length`, `bits`, and `chunk_length`. Each measurement is a vector +This instance of Prio3 supports summing vectors of integers. It has three +parameters: `length`, `bits`, and `chunk_length`. Each measurement is a vector of positive integers with length equal to the `length` parameter. Each element of the measurement is an integer in the range `[0, 2^bits)`. It is RECOMMENDED to set `chunk_length` to an integer near the square root of `length * bits` (see {{parallel-sum-chunk-length}}). -The validity circuit is denoted `SumVec`. Measurements are encoded as a vector -of field elements with length `length * bits`. The field elements in the -encoded vector represent all the bits of the measurement vector's elements, -consecutively, in LSB to MSB order: +The circuit is denoted `SumVec`. Each measurements is encoded as a vector of +field elements with length `length * bits`. The field elements in the encoded +vector represent all the bits of the measurement vector's elements, +consecutively, in LSB to MSB order. -~~~ python -def encode(self, measurement: list[int]) -> list[F]: - if len(measurement) != self.length: - raise ValueError('incorrect measurement length') +The validity circuit uses the `ParallelSum` gadget in {{gadget-parallel-sum}}. +This gadget applies an arithmetic subcircuit to multiple inputs in parallel, +then returns the sum of the results. Along with the subcircuit, the +parallel-sum gadget is parameterized by an integer `count` specifying how many +times to call the subcircuit. It takes in a list of inputs and passes them +through to instances of the subcircuit in the same order. It returns the sum of +the subcircuit outputs. - encoded = [] - for val in measurement: - if val not in range(2**self.bits): - raise ValueError( - 'entry of measurement vector is out of range' - ) +Note that only the `ParallelSum` gadget itself, and not its subcircuit, +participates in the FLP's wire recording during evaluation, gadget consistency +proofs, and proof validation, even though the subcircuit is provided to +`ParallelSum` gadget as an implementation of the `Gadget` interface. - encoded += self.field.encode_into_bit_vector(val, self.bits) - return encoded +The `SumVec` validity circuit checks that the encoded measurement consists of +ones and zeros. Rather than use the `PolyEval` gadget on each element, as in +the `Sum` validity circuit, it instead uses `Mul` subcircuits ({{gadget-mul}}) +and "free" constant multiplication and addition gates to simultaneously +evaluate the same range check polynomial on each element, and multiply by a +constant. One of the two `Mul` subcircuit inputs is equal to a measurement +element multiplied by a power of one of the elements of the joint randomness +vector, and the other is equal to the same measurement element minus one. These +`Mul` subcircuits are evaluated by a `ParallelSum` gadget, and the results are +added up both within the `ParallelSum` gadget and after it. + +The complete circuit is specified below. -def truncate(self, meas: list[F]) -> list[F]: - truncated = [] - for i in range(self.length): - truncated.append(self.field.decode_from_bit_vector( - meas[i * self.bits: (i + 1) * self.bits] - )) - return truncated +~~~ python +class SumVec(Valid[list[int], list[int], F]): + EVAL_OUTPUT_LEN = 1 + length: int + bits: int + chunk_length: int + field: type[F] -def decode( - self, - output: list[F], - _num_measurements: int) -> list[int]: - return [x.as_unsigned() for x in output] -~~~ + def __init__(self, + field: type[F], + length: int, + bits: int, + chunk_length: int): + """ + Instantiate the `SumVec` circuit for measurements with + `length` elements, each in the range `[0, 2^bits)`. + """ + self.field = field + self.length = length + self.bits = bits + self.chunk_length = chunk_length + self.GADGETS = [ParallelSum(Mul(), chunk_length)] + self.GADGET_CALLS = [ + (length * bits + chunk_length - 1) // chunk_length + ] + self.MEAS_LEN = length * bits + self.OUTPUT_LEN = length + self.JOINT_RAND_LEN = self.GADGET_CALLS[0] -This validity circuit uses a `ParallelSum` gadget to achieve a smaller proof -size. This optimization for "parallel-sum circuits" is described in -{{BBCGGI19}}, section 4.4. Briefly, for circuits that add up the output of -multiple identical subcircuits, it is possible to achieve smaller proof sizes -(on the order of O(sqrt(MEAS_LEN)) instead of O(MEAS_LEN)) by packaging more -than one such subcircuit into a gadget. - -The `ParallelSum` gadget is parameterized with an arithmetic subcircuit, and a -`count` of how many times it evaluates that subcircuit. It takes in a list of -inputs and passes them through to instances of the subcircuit in the same order. -It returns the sum of the subcircuit outputs. Note that only the `ParallelSum` -gadget itself, and not its subcircuit, participates in the FLP's wire recording -during evaluation, gadget consistency proofs, and proof validation, even though -the subcircuit is provided to `ParallelSum` as an implementation of the -`Gadget` interface. + def eval( + self, + meas: list[F], + joint_rand: list[F], + num_shares: int) -> list[F]: + out = self.field(0) + shares_inv = self.field(num_shares).inv() + for i in range(self.GADGET_CALLS[0]): + r = joint_rand[i] + r_power = r + inputs: list[Optional[F]] + inputs = [None] * (2 * self.chunk_length) + for j in range(self.chunk_length): + index = i * self.chunk_length + j + if index < len(meas): + meas_elem = meas[index] + else: + meas_elem = self.field(0) + + inputs[j * 2] = r_power * meas_elem + inputs[j * 2 + 1] = meas_elem - shares_inv + + r_power *= r + + out += self.GADGETS[0].eval( + self.field, + cast(list[F], inputs), + ) -~~~ python -def eval(self, field: type[F], inp: list[F]) -> F: - self.check_gadget_eval(inp) - out = field(0) - for i in range(self.count): - start_index = i * self.subcircuit.ARITY - end_index = (i + 1) * self.subcircuit.ARITY - out += self.subcircuit.eval( - field, - inp[start_index:end_index], - ) - return out -~~~ + return [out] -The `SumVec` validity circuit checks that the encoded measurement consists of -ones and zeros. Rather than use the `Range2` gadget on each element, as in the -`Sum` validity circuit, it instead uses `Mul` subcircuits and "free" constant -multiplication and addition gates to simultaneously evaluate the same range -check polynomial on each element, and multiply by a constant. One of the two -`Mul` subcircuit inputs is equal to a measurement element multiplied by a power -of one of the elements of the joint randomness vector, and the other is equal -to the same measurement element minus one. These `Mul` subcircuits are -evaluated by a `ParallelSum` gadget, and the results are added up both within -the `ParallelSum` gadget and after it. + def encode(self, measurement: list[int]) -> list[F]: + encoded = [] + for val in measurement: + encoded += self.field.encode_into_bit_vector( + val, self.bits) + return encoded -~~~ python -def eval( - self, - meas: list[F], - joint_rand: list[F], - num_shares: int) -> list[F]: - out = self.field(0) - shares_inv = self.field(num_shares).inv() - for i in range(self.GADGET_CALLS[0]): - r = joint_rand[i] - r_power = r - inputs: list[Optional[F]] - inputs = [None] * (2 * self.chunk_length) - for j in range(self.chunk_length): - index = i * self.chunk_length + j - if index < len(meas): - meas_elem = meas[index] - else: - meas_elem = self.field(0) - - inputs[j * 2] = r_power * meas_elem - inputs[j * 2 + 1] = meas_elem - shares_inv - - r_power *= r - - out += self.GADGETS[0].eval( - self.field, - cast(list[F], inputs), - ) + def truncate(self, meas: list[F]) -> list[F]: + truncated = [] + for i in range(self.length): + truncated.append(self.field.decode_from_bit_vector( + meas[i * self.bits: (i + 1) * self.bits] + )) + return truncated - return [out] + def decode( + self, + output: list[F], + _num_measurements: int) -> list[int]: + return [x.as_unsigned() for x in output] ~~~ -| Parameter | Value | -|:------------------|:-------------------------------------------------------| -| `GADGETS` | `[ParallelSum(Mul(), chunk_length)]` | -| `GADGET_CALLS` | `[(length * bits + chunk_length - 1) // chunk_length]` | -| `MEAS_LEN` | `length * bits` | -| `OUTPUT_LEN` | `length` | -| `JOINT_RAND_LEN` | `GADGET_CALLS[0]` | -| `EVAL_OUTPUT_LEN` | `1` | -| `Measurement` | `list[int]`, each element in `[0, 2^bits)` | -| `AggResult` | `list[int]` | -{: title="Parameters of validity circuit SumVec."} - #### Selection of `ParallelSum` Chunk Length {#parallel-sum-chunk-length} The `chunk_length` parameter provides a trade-off between the arity of the -`ParallelSum` gadget and the number of times the gadget is called. The proof -length is asymptotically minimized when the chunk length is near the square root -of the length of the measurement. However, the relationship between VDAF -parameters and proof length is complicated, involving two forms of rounding (the -circuit pads the inputs to its last `ParallelSum` gadget call, up to the chunk -length, and proof system rounds the degree of wire polynomials -- determined by -the number of times a gadget is called -- up to the next power of two). -Therefore, the optimal choice of `chunk_length` for a concrete measurement size -will vary, and must be found through trial and error. Setting `chunk_length` -equal to the square root of the appropriate measurement length will result in -proofs up to 50% larger than the optimal proof size. +`ParallelSum` gadget ({{gadget-parallel-sum}}) and the number of times the +gadget is called. The proof length is asymptotically minimized when the chunk +length is near the square root of the length of the measurement. However, the +relationship between VDAF parameters and proof length is complicated, involving +two forms of rounding (the circuit pads the inputs to its last `ParallelSum` +gadget call, up to the chunk length, and proof system rounds the degree of wire +polynomials -- determined by the number of times a gadget is called -- up to +the next power of two). Therefore, the optimal choice of `chunk_length` for a +concrete measurement size will vary, and must be found through trial and error. +Setting `chunk_length` equal to the square root of the appropriate measurement +length will result in proofs up to 50% larger than the optimal proof size. ### Prio3Histogram @@ -4123,105 +4116,113 @@ proofs up to 50% larger than the optimal proof size. | `Xof` | `XofTurboShake128` ({{xof-turboshake128}}) | {: title="Parameters for Prio3Histogram."} -This instance of Prio3 allows for estimating the distribution of some quantity +This variant of Prio3 allows for estimating the distribution of some quantity by computing a simple histogram. Each measurement increments one histogram bucket, out of a set of fixed buckets. (Bucket indexing begins at `0`.) For -example, the buckets might quantize the real numbers, and each measurement would -report the bucket that the corresponding client's real-numbered value falls -into. The aggregate result counts the number of measurements in each bucket. +example, the buckets might quantize the real numbers, and each measurement +would report the bucket that the corresponding client's real-numbered value +falls into. The aggregate result counts the number of measurements in each +bucket. -The validity circuit is denoted `Histogram`. It has two parameters, `length`, -the number of histogram buckets, and `chunk_length`, which is used by by a +The validity circuit is denoted `Histogram`. It has two parameters: `length`, +the number of histogram buckets; and `chunk_length`, which is used by by a circuit optimization described below. It is RECOMMENDED to set `chunk_length` to an integer near the square root of `length` (see {{parallel-sum-chunk-length}}). The measurement is encoded as a one-hot vector representing the bucket into -which the measurement falls: - -~~~ python -def encode(self, measurement: int) -> list[F]: - encoded = [self.field(0)] * self.length - encoded[measurement] = self.field(1) - return encoded - -def truncate(self, meas: list[F]) -> list[F]: - return meas - -def decode( - self, - output: list[F], - _num_measurements: int) -> list[int]: - return [bucket_count.as_unsigned() for bucket_count in output] -~~~ - -The `Histogram` validity circuit checks for one-hotness in two steps, by -checking that the encoded measurement consists of ones and zeros, and by +which the measurement falls. The circuit checks for one-hotness in two steps, +by checking that the encoded measurement consists of ones and zeros, and by checking that the sum of all elements in the encoded measurement is equal to one. The individual checks constitute the output of the circuit. As in the `SumVec` validity circuit ({{prio3sumvec}}), the first part of the -validity circuit uses the `ParallelSum` gadget to perform range checks while -achieving a smaller proof size. The `ParallelSum` gadget uses `Mul` subcircuits -to evaluate a range check polynomial on each element, and includes an additional -constant multiplication. One of the two `Mul` subcircuit inputs is equal to a -measurement element multiplied by a power of an element of the joint randomness -vector, and the other is equal to the same measurement element minus one. The -results are added up both within the `ParallelSum` gadget and after it. +validity circuit uses the `ParallelSum` ({{gadget-parallel-sum}}) gadget to +perform range checks while achieving a smaller proof size. The `ParallelSum` +gadget uses `Mul` subcircuits ({{gadget-mul}}) to evaluate a range check +polynomial on each element, and includes an additional constant multiplication. +One of the two `Mul` subcircuit inputs is equal to a measurement element +multiplied by a power of an element of the joint randomness vector, and the +other is equal to the same measurement element minus one. The results are added +up both within the `ParallelSum` gadget and after it. ~~~ python -def eval( - self, - meas: list[F], - joint_rand: list[F], - num_shares: int) -> list[F]: - # Check that each bucket is one or zero. - range_check = self.field(0) - shares_inv = self.field(num_shares).inv() - for i in range(self.GADGET_CALLS[0]): - r = joint_rand[i] - r_power = r - inputs: list[Optional[F]] - inputs = [None] * (2 * self.chunk_length) - for j in range(self.chunk_length): - index = i * self.chunk_length + j - if index < len(meas): - meas_elem = meas[index] - else: - meas_elem = self.field(0) - - inputs[j * 2] = r_power * meas_elem - inputs[j * 2 + 1] = meas_elem - shares_inv - - r_power *= r - - range_check += self.GADGETS[0].eval( - self.field, - cast(list[F], inputs), - ) +class Histogram(Valid[int, list[int], F]): + EVAL_OUTPUT_LEN = 2 + field: type[F] + length: int + chunk_length: int - # Check that the buckets sum to 1. - sum_check = -shares_inv - for b in meas: - sum_check += b + def __init__(self, + field: type[F], + length: int, + chunk_length: int): + """ + Instantiate an instance of the `Histogram` circuit with the + given `length` and `chunk_length`. + """ + self.field = field + self.length = length + self.chunk_length = chunk_length + self.GADGETS = [ParallelSum(Mul(), chunk_length)] + self.GADGET_CALLS = [ + (length + chunk_length - 1) // chunk_length] + self.MEAS_LEN = self.length + self.OUTPUT_LEN = self.length + self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + + def eval( + self, + meas: list[F], + joint_rand: list[F], + num_shares: int) -> list[F]: + # Check that each bucket is one or zero. + range_check = self.field(0) + shares_inv = self.field(num_shares).inv() + for i in range(self.GADGET_CALLS[0]): + r = joint_rand[i] + r_power = r + inputs: list[Optional[F]] + inputs = [None] * (2 * self.chunk_length) + for j in range(self.chunk_length): + index = i * self.chunk_length + j + if index < len(meas): + meas_elem = meas[index] + else: + meas_elem = self.field(0) + + inputs[j * 2] = r_power * meas_elem + inputs[j * 2 + 1] = meas_elem - shares_inv + + r_power *= r + + range_check += self.GADGETS[0].eval( + self.field, + cast(list[F], inputs), + ) - return [range_check, sum_check] -~~~ + # Check that the buckets sum to 1. + sum_check = -shares_inv + for b in meas: + sum_check += b -Note that this circuit depends on the number of shares into which the -measurement is sharded. This is provided to the FLP by Prio3. + return [range_check, sum_check] -| Parameter | Value | -|:------------------|:------------------------------------------------| -| `GADGETS` | `[ParallelSum(Mul(), chunk_length)]` | -| `GADGET_CALLS` | `[(length + chunk_length - 1) // chunk_length]` | -| `MEAS_LEN` | `length` | -| `OUTPUT_LEN` | `length` | -| `JOINT_RAND_LEN` | `GADGET_CALLS[0]` | -| `EVAL_OUTPUT_LEN` | `2` | -| `Measurement` | `int` | -| `AggResult` | `list[int]` | -{: title="Parameters of validity circuit Histogram."} + def encode(self, measurement: int) -> list[F]: + encoded = [self.field(0)] * self.length + encoded[measurement] = self.field(1) + return encoded + + def truncate(self, meas: list[F]) -> list[F]: + return meas + + def decode( + self, + output: list[F], + _num_measurements: int) -> list[int]: + return [bucket_count.as_unsigned() + for bucket_count in output] +~~~ ### Prio3MultihotCountVec @@ -4240,8 +4241,7 @@ Prio3MultihotCountVec to be composed with a randomized response mechanism, like {{EPK14}}, for providing differential privacy. (For example, each Client would set each entry to one with some small probability.) -Prio3MultihotCountVec uses XofTurboShake128 ({{xof-turboshake128}}) as its XOF. -Its validity circuit is denoted `MultihotCountVec`. It has three parameters: +The validity circuit is denoted `MultihotCountVec` and has three parameters: `length`, the number of of entries in the count vector; `max_weight`, the maximum number of non-zero entries (i.e., the weight must be at most `max_weight`); and `chunk_length`, used the same way as in {{prio3sumvec}} and @@ -4259,94 +4259,125 @@ be encoded with `bits_for_weight` bits. The verifier checks that each entry of the encoded measurement is a bit (i.e., either one or zero). It then decodes the reported weight and subtracts it from `offset + sum(count_vec)`, where `count_vec` is the count vector. The result is -zero if and only if the reported weight is equal to the true weight. - -Encoding, truncation, and decoding are defined as follows: +zero if and only if the reported weight is equal to the true weight. The +complete circuit is defined below. ~~~ python -def encode(self, measurement: list[int]) -> list[F]: - if len(measurement) != self.length: - raise ValueError('invalid Client measurement length') +class MultihotCountVec(Valid[list[int], list[int], F]): + EVAL_OUTPUT_LEN = 2 + field: type[F] - # The first part is the vector of counters. - count_vec = list(map(self.field, measurement)) + def __init__(self, + field: type[F], + length: int, + max_weight: int, + chunk_length: int): + """ + Instantiate an instance of the this circuit with the given + `length`, `max_weight`, and `chunk_length`. - # The second part is the reported weight. - weight_reported = sum(count_vec, self.field(0)) + Pre-conditions: - encoded = [] - encoded += count_vec - encoded += self.field.encode_into_bit_vector( - (self.offset + weight_reported).as_unsigned(), - self.bits_for_weight) - return encoded + - `length > 0` + - `0 < max_weight` and `max_weight <= length` + - `chunk_length > 0` + """ + self.field = field + + # Compute the number of bits to represent `max_weight`. + self.bits_for_weight = max_weight.bit_length() + self.offset = self.field( + 2**self.bits_for_weight - 1 - max_weight) + + # Make sure `offset + length` doesn't overflow the field + # modulus. Otherwise we may not correctly compute the sum + # measurement vector entries during circuit evaluation. + if self.field.MODULUS - self.offset.as_unsigned() <= length: + raise ValueError('length and max_weight are too large ' + 'for the current field size') + + self.length = length + self.max_weight = max_weight + self.chunk_length = chunk_length + self.GADGETS: list[Gadget[F]] = [ + ParallelSum(Mul(), chunk_length), + ] + self.GADGET_CALLS = [ + (length + self.bits_for_weight + chunk_length - 1) + // chunk_length + ] + self.MEAS_LEN = self.length + self.bits_for_weight + self.OUTPUT_LEN = self.length + self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + + def eval( + self, + meas: list[F], + joint_rand: list[F], + num_shares: int) -> list[F]: + # Check that each entry in the input vector is one or zero. + range_check = self.field(0) + shares_inv = self.field(num_shares).inv() + for i in range(self.GADGET_CALLS[0]): + r = joint_rand[i] + r_power = r + inputs: list[Optional[F]] + inputs = [None] * (2 * self.chunk_length) + for j in range(self.chunk_length): + index = i * self.chunk_length + j + if index < len(meas): + meas_elem = meas[index] + else: + meas_elem = self.field(0) + + inputs[j * 2] = r_power * meas_elem + inputs[j * 2 + 1] = meas_elem - shares_inv + + r_power *= r + + range_check += self.GADGETS[0].eval( + self.field, + cast(list[F], inputs), + ) -def truncate(self, meas: list[F]) -> list[F]: - return meas[:self.length] + # Check that the weight `offset` plus the sum of the counters + # is equal to the value claimed by the Client. + count_vec = meas[:self.length] + weight = sum(count_vec, self.field(0)) + weight_reported = \ + self.field.decode_from_bit_vector(meas[self.length:]) + weight_check = self.offset*shares_inv + weight - \ + weight_reported -def decode( - self, - output: list[F], - _num_measurements: int) -> list[int]: - return [bucket_count.as_unsigned() for bucket_count in output] -~~~ + return [range_check, weight_check] -Circuit evaluation is defined as follows: + def encode(self, measurement: list[int]) -> list[F]: + if len(measurement) != self.length: + raise ValueError('invalid Client measurement length') -~~~ python -def eval( - self, - meas: list[F], - joint_rand: list[F], - num_shares: int) -> list[F]: - # Check that each entry in the input vector is one or zero. - range_check = self.field(0) - shares_inv = self.field(num_shares).inv() - for i in range(self.GADGET_CALLS[0]): - r = joint_rand[i] - r_power = r - inputs: list[Optional[F]] - inputs = [None] * (2 * self.chunk_length) - for j in range(self.chunk_length): - index = i * self.chunk_length + j - if index < len(meas): - meas_elem = meas[index] - else: - meas_elem = self.field(0) - - inputs[j * 2] = r_power * meas_elem - inputs[j * 2 + 1] = meas_elem - shares_inv - - r_power *= r - - range_check += self.GADGETS[0].eval( - self.field, - cast(list[F], inputs), - ) + # The first part is the vector of counters. + count_vec = list(map(self.field, measurement)) - # Check that the weight `offset` plus the sum of the counters - # is equal to the value claimed by the Client. - count_vec = meas[:self.length] - weight = sum(count_vec, self.field(0)) - weight_reported = \ - self.field.decode_from_bit_vector(meas[self.length:]) - weight_check = self.offset*shares_inv + weight - \ - weight_reported + # The second part is the reported weight. + weight_reported = sum(count_vec, self.field(0)) - return [range_check, weight_check] -~~~ + encoded = [] + encoded += count_vec + encoded += self.field.encode_into_bit_vector( + (self.offset + weight_reported).as_unsigned(), + self.bits_for_weight) + return encoded -| Parameter | Value | -|:------------------|:------------------------------------------------| -| `GADGETS` | `[ParallelSum(Mul(), chunk_length)]` | -| `GADGET_CALLS` | `[(length + bits_for_weight + chunk_length - 1) // chunk_length]` | -| `MEAS_LEN` | `length + bits_for_weight` | -| `OUTPUT_LEN` | `length` | -| `JOINT_RAND_LEN` | `GADGET_CALLS[0]` | -| `EVAL_OUTPUT_LEN` | `2` | -| `Measurement` | `list[int]` | -| `AggResult` | `list[int]` | -{: title="Parameters of validity circuit MultihotCountVec."} + def truncate(self, meas: list[F]) -> list[F]: + return meas[:self.length] + + def decode( + self, + output: list[F], + _num_measurements: int) -> list[int]: + return [bucket_count.as_unsigned() for + bucket_count in output] +~~~ # Poplar1 {#poplar1} @@ -5923,7 +5954,126 @@ collisions. Applications SHOULD prefer the use of registered code points. --- back -# FLP Gadget Wrappers {#gadget-wrappers} +# FLP Gadgets {#gadgets} + +This section defines gadgets used to specify variants of Prio3 defined in +{{prio3-instantiations}} and to construct FLPs as defined in {{flp-bbcggi19}}. + +## Multiplication {#gadget-mul} + +The multiplication gadget takes in a pair of inputs and multiplies them +together. The degree of this circuit is `2`. + +~~~ python +class Mul(Gadget[F]): + ARITY = 2 + DEGREE = 2 + + def eval(self, _field: type[F], inp: list[F]) -> F: + return inp[0] * inp[1] + + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: + return poly_mul(field, inp_poly[0], inp_poly[1]) +~~~ + +## Polynomial Evaluation {#gadget-poly-eval} + +The polynomial-evaluation gadget takes in a single input `x` and returns +`p(x)`, where `p` is a polynomial specified by the user. Its degree is the same +as `p`. + +~~~ python +class PolyEval(Gadget[F]): + ARITY = 1 + p: list[int] # polynomial coefficients + + def __init__(self, p: list[int]): + """ + Instantiate this gadget with the given polynomial. + """ + # Strip leading zeros. + for i in reversed(range(len(p))): + if p[i] != 0: + p = p[:i+1] + break + if len(p) < 1: + raise ValueError('invalid polynomial: zero length') + + self.p = p + self.DEGREE = len(p) - 1 + + def eval(self, field: type[F], inp: list[F]) -> F: + p = [field(coeff) for coeff in self.p] + return poly_eval(field, p, inp[0]) + + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: + p = [field(coeff) for coeff in self.p] + out = [field(0)] * (self.DEGREE * len(inp_poly[0])) + out[0] = p[0] + x = inp_poly[0] + for i in range(1, len(p)): + for j in range(len(x)): + out[j] += p[i] * x[j] + x = poly_mul(field, x, inp_poly[0]) + return poly_strip(field, out) +~~~ + +## Parallel Sum {#gadget-parallel-sum} + +The parallel-sum gadget applies a subcircuit to multiple inputs in parallel, +then returns the sum of the results. The arity of the gadget is equal to `count` +times the arity of the subcircuit, where `count` denotes the number of times the +subcircuit is called. The degree of the gadget is equal to the degree of the +subcircuit. + +~~~ python +class ParallelSum(Gadget[F]): + subcircuit: Gadget[F] + count: int + + def __init__(self, subcircuit: Gadget[F], count: int): + self.subcircuit = subcircuit + self.count = count + self.ARITY = subcircuit.ARITY * count + self.DEGREE = subcircuit.DEGREE + + def eval(self, field: type[F], inp: list[F]) -> F: + out = field(0) + for i in range(self.count): + start_index = i * self.subcircuit.ARITY + end_index = (i + 1) * self.subcircuit.ARITY + out += self.subcircuit.eval( + field, + inp[start_index:end_index], + ) + return out + + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: + output_poly_length = self.DEGREE * (len(inp_poly[0]) - 1) + 1 + out_sum = [field(0) for _ in range(output_poly_length)] + for i in range(self.count): + start_index = i * self.subcircuit.ARITY + end_index = (i + 1) * self.subcircuit.ARITY + out_current = self.subcircuit.eval_poly( + field, + inp_poly[start_index:end_index] + ) + for j in range(output_poly_length): + out_sum[j] += out_current[j] + return poly_strip(field, out_sum) +~~~ + +## Shims for Generating and Verifying Proofs {#gadget-wrappers} + +This section specifies two "shim" gadgets, one for generating an FLP as +specified in {{flp-bbcggi19-construction-prove}} and another for querying an FLP +as specified in {{flp-bbcggi19-construction-query}}. ~~~ python class ProveGadget(Gadget[F]): @@ -5977,7 +6127,6 @@ class ProveGadget(Gadget[F]): wrapped_valid.GADGETS = wrapped_gadgets return wrapped_valid - class QueryGadget(Gadget[F]): """ Gadget wrapper that records the input wires for each evaluation. diff --git a/poc/tests/test_flp_bbcggi19.py b/poc/tests/test_flp_bbcggi19.py index 6e01b605..fd100bab 100644 --- a/poc/tests/test_flp_bbcggi19.py +++ b/poc/tests/test_flp_bbcggi19.py @@ -2,8 +2,8 @@ from vdaf_poc.field import Field64, Field96, Field128, NttField from vdaf_poc.flp_bbcggi19 import (Count, FlpBBCGGI19, Histogram, Mul, - MultihotCountVec, PolyEval, Range2, Sum, - SumVec, Valid) + MultihotCountVec, PolyEval, Sum, SumVec, + Valid) from vdaf_poc.test_utils import TestFlpBBCGGI19 Measurement = TypeVar("Measurement") @@ -161,12 +161,12 @@ def test(self) -> None: class TestGadgets(TestFlpBBCGGI19): - def test_range2(self) -> None: - self.run_gadget_test(Range2(), Field128, 10) + def test_poly_eval_range2(self) -> None: + self.run_gadget_test(PolyEval([0, -1, 1]), Field128, 10) - def test_polyeval(self) -> None: + def test_poly_eval(self) -> None: self.run_gadget_test( - PolyEval([0, Field128.MODULUS - 23, 1, 3]), + PolyEval([0, -23, 1, 3]), Field128, 10, ) diff --git a/poc/vdaf_poc/field.py b/poc/vdaf_poc/field.py index ae0c18d0..cc6789ac 100644 --- a/poc/vdaf_poc/field.py +++ b/poc/vdaf_poc/field.py @@ -16,9 +16,10 @@ class Field: ENCODED_SIZE: int def __init__(self, val: int): - assert val >= 0 assert val < self.MODULUS - self.val = val + assert val > -self.MODULUS + # Interpret negative integers as additive inverses of field elements. + self.val = val % self.MODULUS @classmethod def zeros(cls, length: int) -> list[Self]: diff --git a/poc/vdaf_poc/flp_bbcggi19.py b/poc/vdaf_poc/flp_bbcggi19.py index 022b586d..8dc7ead6 100644 --- a/poc/vdaf_poc/flp_bbcggi19.py +++ b/poc/vdaf_poc/flp_bbcggi19.py @@ -475,69 +475,40 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: # GADGETS # +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== class Mul(Gadget[F]): ARITY = 2 DEGREE = 2 - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval(self, _field: type[F], inp: list[F]) -> F: - self.check_gadget_eval(inp) + self.check_gadget_eval(inp) # REMOVE ME return inp[0] * inp[1] - def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: self.check_gadget_eval_poly(inp_poly) # REMOVE ME return poly_mul(field, inp_poly[0], inp_poly[1]) -class Range2(Gadget[F]): - """ - Takes one input and computes x^2 - x. - """ - - ARITY = 1 - DEGREE = 2 - - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== - def eval(self, _field: type[F], inp: list[F]) -> F: - self.check_gadget_eval(inp) - return inp[0] * inp[0] - inp[0] - - def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: - self.check_gadget_eval_poly(inp_poly) # REMOVE ME - output_poly_length = self.DEGREE * (len(inp_poly[0]) - 1) + 1 - out = [field(0) for _ in range(output_poly_length)] - x = inp_poly[0] - x_squared = poly_mul(field, x, x) - for (i, x_i) in enumerate(x): - out[i] -= x_i - for (i, x_squared_i) in enumerate(x_squared): - out[i] += x_squared_i - return poly_strip(field, out) - - +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== class PolyEval(Gadget[F]): - # Polynomial coefficients. - p: list[int] - ARITY = 1 + p: list[int] # polynomial coefficients def __init__(self, p: list[int]): """ Instantiate this gadget with the given polynomial. """ - # Strip leading zeros. for i in reversed(range(len(p))): if p[i] != 0: p = p[:i+1] break - if len(p) < 1: raise ValueError('invalid polynomial: zero length') @@ -549,7 +520,9 @@ def eval(self, field: type[F], inp: list[F]) -> F: p = [field(coeff) for coeff in self.p] return poly_eval(field, p, inp[0]) - def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: self.check_gadget_eval_poly(inp_poly) # REMOVE ME p = [field(coeff) for coeff in self.p] out = [field(0)] * (self.DEGREE * len(inp_poly[0])) @@ -562,25 +535,29 @@ def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: return poly_strip(field, out) +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== class ParallelSum(Gadget[F]): + # REMOVE ME """ - Evaluates a subcircuit (represented by a Gadget) on multiple inputs, adds - the results, and returns the sum. - - The `count` parameter determines how many times the `subcircuit` gadget will - be called. The arity of this gadget is equal to the arity of the subcircuit - multiplied by the `count` parameter, and the degree of this gadget is equal - to the degree of the subcircuit. Input wires will be sequentially mapped to - input wires of the subcircuit instances. - - Section 4.4 of the BBCGGI19 paper outlines an optimization for circuits - fitting the parallel sum form, wherein a sum of n identical subcircuits can - be replaced with sqrt(n) parallel sum gadgets, each adding up sqrt(n) - subcircuit results. This results in smaller proofs, since the proof size - linearly depends on both the arity of gadgets and the number of times - gadgets are called. + Evaluates a subcircuit (represented by a Gadget) on multiple + inputs, adds the results, and returns the sum. + + The `count` parameter determines how many times the `subcircuit` + gadget will be called. The arity of this gadget is equal to the + arity of the subcircuit multiplied by the `count` parameter, and + the degree of this gadget is equal to the degree of the + subcircuit. Input wires will be sequentially mapped to input + wires of the subcircuit instances. + + Section 4.4 of the BBCGGI19 paper outlines an optimization for + circuits fitting the parallel sum form, wherein a sum of n + identical subcircuits can be replaced with sqrt(n) parallel sum + gadgets, each adding up sqrt(n) subcircuit results. This results + in smaller proofs, since the proof size linearly depends on both + the arity of gadgets and the number of times gadgets are called. """ - subcircuit: Gadget[F] count: int @@ -590,10 +567,6 @@ def __init__(self, subcircuit: Gadget[F], count: int): self.ARITY = subcircuit.ARITY * count self.DEGREE = subcircuit.DEGREE - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval(self, field: type[F], inp: list[F]) -> F: self.check_gadget_eval(inp) # REMOVE ME out = field(0) @@ -606,7 +579,9 @@ def eval(self, field: type[F], inp: list[F]) -> F: ) return out - def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: + def eval_poly(self, + field: type[F], + inp_poly: list[list[F]]) -> list[F]: self.check_gadget_eval_poly(inp_poly) # REMOVE ME output_poly_length = self.DEGREE * (len(inp_poly[0]) - 1) + 1 out_sum = [field(0) for _ in range(output_poly_length)] @@ -626,12 +601,10 @@ def eval_poly(self, field: type[F], inp_poly: list[list[F]]) -> list[F]: # TYPES # -class Count( - Valid[ - int, # Measurement, 0 or 1 - int, # AggResult - F, - ]): +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== +class Count(Valid[int, int, F]): GADGETS: list[Gadget[F]] = [Mul()] GADGET_CALLS = [1] MEAS_LEN = 1 @@ -645,52 +618,48 @@ class Count( def __init__(self, field: type[F]): self.field = field - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval( self, meas: list[F], joint_rand: list[F], _num_shares: int) -> list[F]: self.check_valid_eval(meas, joint_rand) # REMOVE ME - squared = self.GADGETS[0].eval(self.field, [meas[0], meas[0]]) + squared = self.GADGETS[0].eval(self.field, + [meas[0], meas[0]]) return [squared - meas[0]] def encode(self, measurement: int) -> list[F]: - if measurement not in [0, 1]: - raise ValueError('measurement out of range') + if measurement not in range(2): # REMOVE ME + raise ValueError('measurement out of range') # REMOVE ME return [self.field(measurement)] def truncate(self, meas: list[F]) -> list[F]: - if len(meas) != 1: - raise ValueError('incorrect encoded measurement length') + if len(meas) != 1: # REMOVE ME + raise ValueError('incorrect measurement length') # REMOVE ME return meas def decode(self, output: list[F], _num_measurements: int) -> int: return output[0].as_unsigned() -class Histogram( - Valid[ - int, # Measurement, `range(length)` - list[int], # AggResult - F, - ]): - # Class object for the field. +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== +class Histogram(Valid[int, list[int], F]): + EVAL_OUTPUT_LEN = 2 field: type[F] length: int chunk_length: int - EVAL_OUTPUT_LEN = 2 - - def __init__(self, field: type[F], length: int, chunk_length: int): + def __init__(self, + field: type[F], + length: int, + chunk_length: int): """ - Instantiate an instance of the `Histogram` circuit with the given - length and chunk_length. + Instantiate an instance of the `Histogram` circuit with the + given `length` and `chunk_length`. """ - + # REMOVE ME if length <= 0: raise ValueError('invalid length') if chunk_length <= 0: @@ -699,18 +668,13 @@ def __init__(self, field: type[F], length: int, chunk_length: int): self.field = field self.length = length self.chunk_length = chunk_length - self.GADGETS: list[Gadget[F]] = [ - ParallelSum(Mul(), chunk_length), - ] - self.GADGET_CALLS = [(length + chunk_length - 1) // chunk_length] + self.GADGETS = [ParallelSum(Mul(), chunk_length)] + self.GADGET_CALLS = [ + (length + chunk_length - 1) // chunk_length] self.MEAS_LEN = self.length self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval( self, meas: list[F], @@ -750,11 +714,6 @@ def eval( return [range_check, sum_check] - # NOTE: The encode(), truncate(), and decode() methods are excerpted - # in the document, de-indented. Their width should be limited to 69 - # columns after de-indenting, or 73 columns before de-indenting, to - # avoid warnings from xml2rfc. - # =================================================================== def encode(self, measurement: int) -> list[F]: encoded = [self.field(0)] * self.length encoded[measurement] = self.field(1) @@ -767,7 +726,8 @@ def decode( self, output: list[F], _num_measurements: int) -> list[int]: - return [bucket_count.as_unsigned() for bucket_count in output] + return [bucket_count.as_unsigned() + for bucket_count in output] def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: test_vec['length'] = int(self.length) @@ -775,12 +735,11 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: return ['length', 'chunk_length'] -class MultihotCountVec( - Valid[ - list[int], # Measurement, a vector of bits - list[int], # AggResult, a vector of counts - F, - ]): +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== +class MultihotCountVec(Valid[list[int], list[int], F]): + # REMOVE ME """ A validity circuit that checks each Client's measurement is a bit vector with at most `max_weight` number of 1s. We call the number @@ -805,10 +764,8 @@ class MultihotCountVec( `count_vec` is the count vector. The result is zero if and only if the reported weight is equal to the true weight. """ - # Class object for the field. - field: type[F] - EVAL_OUTPUT_LEN = 2 + field: type[F] def __init__(self, field: type[F], @@ -825,6 +782,7 @@ def __init__(self, - `0 < max_weight` and `max_weight <= length` - `chunk_length > 0` """ + # REMOVE ME if length <= 0: raise ValueError('invalid length') if max_weight <= 0 or max_weight > length: @@ -836,7 +794,8 @@ def __init__(self, # Compute the number of bits to represent `max_weight`. self.bits_for_weight = max_weight.bit_length() - self.offset = self.field((2 ** self.bits_for_weight) - 1 - max_weight) + self.offset = self.field( + 2**self.bits_for_weight - 1 - max_weight) # Make sure `offset + length` doesn't overflow the field # modulus. Otherwise we may not correctly compute the sum @@ -852,16 +811,13 @@ def __init__(self, ParallelSum(Mul(), chunk_length), ] self.GADGET_CALLS = [ - (length + self.bits_for_weight + chunk_length - 1) // chunk_length + (length + self.bits_for_weight + chunk_length - 1) + // chunk_length ] self.MEAS_LEN = self.length + self.bits_for_weight self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval( self, meas: list[F], @@ -905,11 +861,6 @@ def eval( return [range_check, weight_check] - # NOTE: The encode(), truncate(), and decode() methods are excerpted - # in the document, de-indented. Their width should be limited to 69 - # columns after de-indenting, or 73 columns before de-indenting, to - # avoid warnings from xml2rfc. - # =================================================================== def encode(self, measurement: list[int]) -> list[F]: if len(measurement) != self.length: raise ValueError('invalid Client measurement length') @@ -934,7 +885,8 @@ def decode( self, output: list[F], _num_measurements: int) -> list[int]: - return [bucket_count.as_unsigned() for bucket_count in output] + return [bucket_count.as_unsigned() for + bucket_count in output] def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: test_vec['length'] = int(self.length) @@ -943,25 +895,26 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: return ['length', 'max_weight', 'chunk_length'] -class SumVec( - Valid[ - list[int], # Measurement - list[int], # AggResult - F, - ]): +# NOTE: This class is excerpted in the document. Its width should be +# limited to 69 columns to avoid warnings from xml2rfc. +# =================================================================== +class SumVec(Valid[list[int], list[int], F]): + EVAL_OUTPUT_LEN = 1 length: int bits: int chunk_length: int - field: type[F] - EVAL_OUTPUT_LEN = 1 - def __init__(self, field: type[F], length: int, bits: int, chunk_length: int): + def __init__(self, + field: type[F], + length: int, + bits: int, + chunk_length: int): """ - Instantiate the `SumVec` circuit for measurements with `length` - elements, each in the range `[0, 2^bits)`. + Instantiate the `SumVec` circuit for measurements with + `length` elements, each in the range `[0, 2^bits)`. """ - + # REMOVE ME if 2 ** bits >= field.MODULUS: raise ValueError('bit size exceeds field modulus') if bits <= 0: @@ -975,7 +928,7 @@ def __init__(self, field: type[F], length: int, bits: int, chunk_length: int): self.length = length self.bits = bits self.chunk_length = chunk_length - self.GADGETS: list[Gadget[F]] = [ParallelSum(Mul(), chunk_length)] + self.GADGETS = [ParallelSum(Mul(), chunk_length)] self.GADGET_CALLS = [ (length * bits + chunk_length - 1) // chunk_length ] @@ -983,10 +936,6 @@ def __init__(self, field: type[F], length: int, bits: int, chunk_length: int): self.OUTPUT_LEN = length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] - # NOTE: This method is excerpted in the document, de-indented. Its - # width should be limited to 69 columns after de-indenting, or 73 - # columns before de-indenting, to avoid warnings from xml2rfc. - # =================================================================== def eval( self, meas: list[F], @@ -1020,23 +969,21 @@ def eval( return [out] - # NOTE: The encode(), truncate(), and decode() methods are excerpted - # in the document, de-indented. Their width should be limited to 69 - # columns after de-indenting, or 73 columns before de-indenting, to - # avoid warnings from xml2rfc. - # =================================================================== def encode(self, measurement: list[int]) -> list[F]: + # REMOVE ME if len(measurement) != self.length: raise ValueError('incorrect measurement length') encoded = [] for val in measurement: + # REMOVE ME if val not in range(2**self.bits): raise ValueError( 'entry of measurement vector is out of range' ) - encoded += self.field.encode_into_bit_vector(val, self.bits) + encoded += self.field.encode_into_bit_vector( + val, self.bits) return encoded def truncate(self, meas: list[F]) -> list[F]: @@ -1060,18 +1007,18 @@ def test_vec_set_type_param(self, test_vec: dict[str, Any]) -> list[str]: return ['length', 'bits', 'chunk_length'] -class Sum( - Valid[ - int, # Measurement, `range(self.max_measurement + 1)` - int, # AggResult - F, - ]): - field: type[F] - GADGETS: list[Gadget[F]] = [Range2()] +# NOTE: This class is excerpted in the document, de-indented. Its +# width should be limited to 69 columns to avoid warnings from +# xml2rfc. +# =================================================================== +class Sum(Valid[int, int, F]): + GADGETS: list[Gadget[F]] = [PolyEval([0, -1, 1])] JOINT_RAND_LEN = 0 OUTPUT_LEN = 1 + field: type[F] def __init__(self, field: type[F], max_measurement: int): + # REMOVE ME """ A circuit that checks that the measurement is in range `[0, max_measurement]`. This is accomplished by encoding the @@ -1101,10 +1048,10 @@ def __init__(self, field: type[F], max_measurement: int): """ self.field = field self.bits = max_measurement.bit_length() - self.offset = self.field(2 ** self.bits - 1 - max_measurement) + self.offset = self.field(2**self.bits - 1 - max_measurement) self.max_measurement = max_measurement - if 2 ** self.bits >= self.field.MODULUS: - raise ValueError('bound exceeds field modulus') + if 2 ** self.bits >= self.field.MODULUS: # REMOVE ME + raise ValueError('bound exceeds field modulus') # REMOVE ME self.GADGET_CALLS = [2 * self.bits] self.MEAS_LEN = 2 * self.bits