Skip to content

Commit

Permalink
Encode integer as bit vector representation
Browse files Browse the repository at this point in the history
Add class methods on "Field" to encode integer as bit vector representation,
as a vector of field elements, and vice versa, because it is a common
functionality used by both Prio3Sum and Prio3SumVec, and increases readability.
The class methods are aligned with libprio-rs implementations:
https://github.com/divviup/libprio-rs/blob/release/0.14/src/field.rs#L232-L267
  • Loading branch information
junyechen1996 authored and cjpatton committed Sep 13, 2023
1 parent d144076 commit 510c00c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 33 deletions.
63 changes: 46 additions & 17 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ unsigned integer into a field element:
* `Field(integer: Unsigned)` returns `integer` represented as a field element.
The value of `integer` MUST be less than `Field.MODULUS`.

Finally, each concrete `Field` has two derived class methods, one for encoding
Each concrete `Field` has two derived class methods, one for encoding
a vector of field elements as a byte string and another for decoding a vector of
field elements.

Expand All @@ -1673,6 +1673,43 @@ def decode_vec(Field, encoded: Bytes) -> Vec[Field]:
~~~
{: #field-derived-methods title="Derived class methods for finite fields."}

Finally, `Field` implements the following methods for representing a value as
a sequence of field elements, each of which represents a bit of the input.

~~~
def encode_into_bit_vector(Field,
val: Unsigned,
bits: Unsigned) -> Vec[Field]:
"""
Encode the bit representation of `val` with at most `bits` number
of bits, as a vector of field elements.
"""
if val >= 2 ** bits:
# Sanity check we are able to represent `val` with `bits`
# number of bits.
raise ValueError("Number of bits is not enough to represent "
"the input integer.")
encoded = []
for l in range(bits):
encoded.append(Field((val >> l) & 1))
return encoded

def decode_from_bit_vector(Field, vec: Vec[Field]) -> Field:
"""
Decode the field element from the bit representation, expressed
as a vector of field elements `vec`.
"""
bits = len(vec)
if Field.MODULUS >> bits == 0:
raise ValueError("Number of bits is too large to be "
"represented by field modulus.")
decoded = Field(0)
for (l, bit) in enumerate(vec):
decoded += Field(1 << l) * bit
return decoded
~~~
{: #field-bit-rep title="Derived class methods to encode integers into bit vector representation."}

### Auxiliary Functions

The following auxiliary functions on vectors of field elements are used in the
Expand Down Expand Up @@ -3127,17 +3164,11 @@ def encode(self, measurement):
if 0 > measurement or measurement >= 2 ** self.MEAS_LEN:
raise ERR_INPUT

encoded = []
for l in range(self.MEAS_LEN):
encoded.append(self.Field((measurement >> l) & 1))
return encoded
return self.Field.encode_into_bit_vector(measurement,
self.MEAS_LEN)

def truncate(self, meas):
decoded = self.Field(0)
for (l, b) in enumerate(meas):
w = self.Field(1 << l)
decoded += w * b
return [decoded]
return [self.Field.decode_from_bit_vector(meas)]

def decode(self, output, _num_measurements):
return output[0].as_unsigned()
Expand Down Expand Up @@ -3203,17 +3234,15 @@ def encode(self, measurement: Vec[Unsigned]):
if 0 > val or val >= 2 ** self.bits:
raise ERR_INPUT

for l in range(self.bits):
encoded.append(self.Field((val >> l) & 1))

encoded += self.Field.encode_into_bit_vector(val, self.bits)
return encoded

def truncate(self, meas):
truncated = [self.Field(0) for _ in range(self.length)]
truncated = []
for i in range(self.length):
for j in range(self.bits):
weight = self.Field(1 << j)
truncated[i] += weight * meas[i * self.bits + j]
truncated.append(self.Field.decode_from_bit_vector(
meas[i * self.bits: (i + 1) * self.bits]
))
return truncated

def decode(self, output, _num_measurements):
Expand Down
40 changes: 40 additions & 0 deletions poc/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,39 @@ def decode_vec(Field, encoded: Bytes) -> Vec[Field]:
vec.append(Field(x))
return vec

@classmethod
def encode_into_bit_vector(Field,
val: Unsigned,
bits: Unsigned) -> Vec[Field]:
"""
Encode the bit representation of `val` with at most `bits` number
of bits, as a vector of field elements.
"""
if val >= 2 ** bits:
# Sanity check we are able to represent `val` with `bits`
# number of bits.
raise ValueError("Number of bits is not enough to represent "
"the input integer.")
encoded = []
for l in range(bits):
encoded.append(Field((val >> l) & 1))
return encoded

@classmethod
def decode_from_bit_vector(Field, vec: Vec[Field]) -> Field:
"""
Decode the field element from the bit representation, expressed
as a vector of field elements `vec`.
"""
bits = len(vec)
if Field.MODULUS >> bits == 0:
raise ValueError("Number of bits is too large to be "
"represented by field modulus.")
decoded = Field(0)
for (l, bit) in enumerate(vec):
decoded += Field(1 << l) * bit
return decoded

def __add__(self, other: Field) -> Field:
return self.__class__(self.val + other.val)

Expand Down Expand Up @@ -252,6 +285,13 @@ def test_field(cls):
got = cls.decode_vec(cls.encode_vec(want))
assert got == want

# Test encoding integer as bit vector.
vals = [i for i in range(15)]
bits = 4
for val in vals:
encoded = cls.encode_into_bit_vector(val, bits)
assert cls.decode_from_bit_vector(encoded).as_unsigned() == val


def test_fft_field(cls):
test_field(cls)
Expand Down
33 changes: 17 additions & 16 deletions poc/flp_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,11 @@ def encode(self, measurement):
if 0 > measurement or measurement >= 2 ** self.MEAS_LEN:
raise ERR_INPUT

encoded = []
for l in range(self.MEAS_LEN):
encoded.append(self.Field((measurement >> l) & 1))
return encoded
return self.Field.encode_into_bit_vector(measurement,
self.MEAS_LEN)

def truncate(self, meas):
decoded = self.Field(0)
for (l, b) in enumerate(meas):
w = self.Field(1 << l)
decoded += w * b
return [decoded]
return [self.Field.decode_from_bit_vector(meas)]

def decode(self, output, _num_measurements):
return output[0].as_unsigned()
Expand Down Expand Up @@ -724,17 +718,15 @@ def encode(self, measurement: Vec[Unsigned]):
if val < 0 or val >= 2 ** self.bits:
raise ERR_INPUT

for l in range(self.bits):
encoded.append(self.Field((val >> l) & 1))

encoded += self.Field.encode_into_bit_vector(val, self.bits)
return encoded

def truncate(self, meas):
truncated = [self.Field(0) for _ in range(self.length)]
truncated = []
for i in range(self.length):
for j in range(self.bits):
weight = self.Field(1 << j)
truncated[i] += weight * meas[i * self.bits + j]
truncated.append(self.Field.decode_from_bit_vector(
meas[i * self.bits: (i + 1) * self.bits]
))
return truncated

def decode(self, output, _num_measurements):
Expand Down Expand Up @@ -856,6 +848,9 @@ def test():
(flp.encode(2 ** 10 - 1), True),
(flp.Field.rand_vec(10), False),
])
# Roundtrip test with no proof generated.
for meas in [0, 100, 2 ** 10 - 1]:
assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1)

flp = FlpGeneric(Histogram(4, 2))
test_flp_generic(flp, [
Expand All @@ -868,6 +863,12 @@ def test():
(flp.Field.rand_vec(4), False),
])

# SumVec with length 2, bits 4, chunk len 1.
flp = FlpGeneric(SumVec(2, 4, 1))
# Roundtrip test with no proof generated.
for meas in [[1, 2], [3, 4], [5, 6], [7, 8]]:
assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1)

flp = FlpGeneric(TestMultiGadget())
test_flp_generic(flp, [
(flp.encode(0), True),
Expand Down

0 comments on commit 510c00c

Please sign in to comment.