From 75d9a95533781602bc61ff60ab205c11d31e7057 Mon Sep 17 00:00:00 2001 From: David Cook Date: Thu, 17 Oct 2024 13:20:46 -0500 Subject: [PATCH] Move encode() before eval() --- draft-irtf-cfrg-vdaf.md | 88 +++++++++++++-------------- poc/vdaf_poc/flp_bbcggi19.py | 112 +++++++++++++++++------------------ 2 files changed, 100 insertions(+), 100 deletions(-) diff --git a/draft-irtf-cfrg-vdaf.md b/draft-irtf-cfrg-vdaf.md index a1779151..7b7873bc 100644 --- a/draft-irtf-cfrg-vdaf.md +++ b/draft-irtf-cfrg-vdaf.md @@ -3770,6 +3770,9 @@ class Count(Valid[int, int, F]): def __init__(self, field: type[F]): self.field = field + def encode(self, measurement: int) -> list[F]: + return [self.field(measurement)] + def eval( self, meas: list[F], @@ -3779,9 +3782,6 @@ class Count(Valid[int, int, F]): [meas[0], meas[0]]) return [squared - meas[0]] - def encode(self, measurement: int) -> list[F]: - return [self.field(measurement)] - def truncate(self, meas: list[F]) -> list[F]: return meas @@ -3845,6 +3845,18 @@ class Sum(Valid[int, int, F]): self.MEAS_LEN = 2 * self.bits self.EVAL_OUTPUT_LEN = 2 * self.bits + 1 + def encode(self, measurement: int) -> list[F]: + encoded = [] + encoded += self.field.encode_into_bit_vec( + measurement, + self.bits + ) + encoded += self.field.encode_into_bit_vec( + measurement + self.offset.int(), + self.bits + ) + return encoded + def eval( self, meas: list[F], @@ -3862,18 +3874,6 @@ class Sum(Valid[int, int, F]): out.append(range_check) return out - def encode(self, measurement: int) -> list[F]: - encoded = [] - encoded += self.field.encode_into_bit_vec( - measurement, - self.bits - ) - encoded += self.field.encode_into_bit_vec( - measurement + self.offset.int(), - self.bits - ) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return [self.field.decode_from_bit_vec(meas[:self.bits])] @@ -3957,6 +3957,13 @@ class SumVec(Valid[list[int], list[int], F]): self.OUTPUT_LEN = length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: list[int]) -> list[F]: + encoded = [] + for val in measurement: + encoded += self.field.encode_into_bit_vec( + val, self.bits) + return encoded + def eval( self, meas: list[F], @@ -3988,13 +3995,6 @@ class SumVec(Valid[list[int], list[int], F]): return [out] - def encode(self, measurement: list[int]) -> list[F]: - encoded = [] - for val in measurement: - encoded += self.field.encode_into_bit_vec( - val, self.bits) - return encoded - def truncate(self, meas: list[F]) -> list[F]: truncated = [] for i in range(self.length): @@ -4094,6 +4094,11 @@ class Histogram(Valid[int, list[int], F]): self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: int) -> list[F]: + encoded = [self.field(0)] * self.length + encoded[measurement] = self.field(1) + return encoded + def eval( self, meas: list[F], @@ -4131,11 +4136,6 @@ class Histogram(Valid[int, list[int], F]): return [range_check, sum_check] - def encode(self, measurement: int) -> list[F]: - encoded = [self.field(0)] * self.length - encoded[measurement] = self.field(1) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return meas @@ -4232,6 +4232,23 @@ class MultihotCountVec(Valid[list[int], list[int], F]): self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: list[int]) -> list[F]: + if len(measurement) != self.length: + raise ValueError('invalid Client measurement length') + + # The first part is the vector of counters. + count_vec = list(map(self.field, measurement)) + + # The second part is the reported weight. + weight_reported = sum(count_vec, self.field(0)) + + encoded = [] + encoded += count_vec + encoded += self.field.encode_into_bit_vec( + (self.offset + weight_reported).int(), + self.bits_for_weight) + return encoded + def eval( self, meas: list[F], @@ -4273,23 +4290,6 @@ class MultihotCountVec(Valid[list[int], list[int], F]): return [range_check, weight_check] - def encode(self, measurement: list[int]) -> list[F]: - if len(measurement) != self.length: - raise ValueError('invalid Client measurement length') - - # The first part is the vector of counters. - count_vec = list(map(self.field, measurement)) - - # The second part is the reported weight. - weight_reported = sum(count_vec, self.field(0)) - - encoded = [] - encoded += count_vec - encoded += self.field.encode_into_bit_vec( - (self.offset + weight_reported).int(), - self.bits_for_weight) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return meas[:self.length] diff --git a/poc/vdaf_poc/flp_bbcggi19.py b/poc/vdaf_poc/flp_bbcggi19.py index b2b975c6..0762020f 100644 --- a/poc/vdaf_poc/flp_bbcggi19.py +++ b/poc/vdaf_poc/flp_bbcggi19.py @@ -618,6 +618,11 @@ class Count(Valid[int, int, F]): def __init__(self, field: type[F]): self.field = field + def encode(self, measurement: int) -> list[F]: + if measurement not in range(2): # REMOVE ME + raise ValueError('measurement out of range') # REMOVE ME + return [self.field(measurement)] + def eval( self, meas: list[F], @@ -628,11 +633,6 @@ def eval( [meas[0], meas[0]]) return [squared - meas[0]] - def encode(self, measurement: int) -> list[F]: - if measurement not in range(2): # REMOVE ME - raise ValueError('measurement out of range') # REMOVE ME - return [self.field(measurement)] - def truncate(self, meas: list[F]) -> list[F]: if len(meas) != 1: # REMOVE ME raise ValueError('incorrect measurement length') # REMOVE ME @@ -675,6 +675,11 @@ def __init__(self, self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: int) -> list[F]: + encoded = [self.field(0)] * self.length + encoded[measurement] = self.field(1) + return encoded + def eval( self, meas: list[F], @@ -714,11 +719,6 @@ def eval( return [range_check, sum_check] - def encode(self, measurement: int) -> list[F]: - encoded = [self.field(0)] * self.length - encoded[measurement] = self.field(1) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return meas @@ -818,6 +818,23 @@ def __init__(self, self.OUTPUT_LEN = self.length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: list[int]) -> list[F]: + if len(measurement) != self.length: + raise ValueError('invalid Client measurement length') + + # The first part is the vector of counters. + count_vec = list(map(self.field, measurement)) + + # The second part is the reported weight. + weight_reported = sum(count_vec, self.field(0)) + + encoded = [] + encoded += count_vec + encoded += self.field.encode_into_bit_vec( + (self.offset + weight_reported).int(), + self.bits_for_weight) + return encoded + def eval( self, meas: list[F], @@ -861,23 +878,6 @@ def eval( return [range_check, weight_check] - def encode(self, measurement: list[int]) -> list[F]: - if len(measurement) != self.length: - raise ValueError('invalid Client measurement length') - - # The first part is the vector of counters. - count_vec = list(map(self.field, measurement)) - - # The second part is the reported weight. - weight_reported = sum(count_vec, self.field(0)) - - encoded = [] - encoded += count_vec - encoded += self.field.encode_into_bit_vec( - (self.offset + weight_reported).int(), - self.bits_for_weight) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return meas[:self.length] @@ -936,6 +936,23 @@ def __init__(self, self.OUTPUT_LEN = length self.JOINT_RAND_LEN = self.GADGET_CALLS[0] + def encode(self, measurement: list[int]) -> list[F]: + # REMOVE ME + if len(measurement) != self.length: + raise ValueError('incorrect measurement length') + + encoded = [] + for val in measurement: + # REMOVE ME + if val not in range(2**self.bits): + raise ValueError( + 'entry of measurement vector is out of range' + ) + + encoded += self.field.encode_into_bit_vec( + val, self.bits) + return encoded + def eval( self, meas: list[F], @@ -969,23 +986,6 @@ def eval( return [out] - def encode(self, measurement: list[int]) -> list[F]: - # REMOVE ME - if len(measurement) != self.length: - raise ValueError('incorrect measurement length') - - encoded = [] - for val in measurement: - # REMOVE ME - if val not in range(2**self.bits): - raise ValueError( - 'entry of measurement vector is out of range' - ) - - encoded += self.field.encode_into_bit_vec( - val, self.bits) - return encoded - def truncate(self, meas: list[F]) -> list[F]: truncated = [] for i in range(self.length): @@ -1057,6 +1057,18 @@ def __init__(self, field: type[F], max_measurement: int): self.MEAS_LEN = 2 * self.bits self.EVAL_OUTPUT_LEN = 2 * self.bits + 1 + def encode(self, measurement: int) -> list[F]: + encoded = [] + encoded += self.field.encode_into_bit_vec( + measurement, + self.bits + ) + encoded += self.field.encode_into_bit_vec( + measurement + self.offset.int(), + self.bits + ) + return encoded + def eval( self, meas: list[F], @@ -1075,18 +1087,6 @@ def eval( out.append(range_check) return out - def encode(self, measurement: int) -> list[F]: - encoded = [] - encoded += self.field.encode_into_bit_vec( - measurement, - self.bits - ) - encoded += self.field.encode_into_bit_vec( - measurement + self.offset.int(), - self.bits - ) - return encoded - def truncate(self, meas: list[F]) -> list[F]: return [self.field.decode_from_bit_vec(meas[:self.bits])]