From 83d59af7dc917dc4cba039cb484a62721aa81c51 Mon Sep 17 00:00:00 2001 From: David Cook Date: Wed, 16 Oct 2024 12:16:19 -0500 Subject: [PATCH] Add a common base class to Daf and Vdaf --- poc/vdaf_poc/daf.py | 63 ++++++++++++++++++++++++++++++++++++++++++-- poc/vdaf_poc/vdaf.py | 16 ++++++++--- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/poc/vdaf_poc/daf.py b/poc/vdaf_poc/daf.py index f05c2b61..5019b8b4 100644 --- a/poc/vdaf_poc/daf.py +++ b/poc/vdaf_poc/daf.py @@ -1,7 +1,7 @@ """Definition of DAFs.""" from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, TypeVar, override from vdaf_poc.common import gen_rand @@ -14,13 +14,66 @@ AggResult = TypeVar("AggResult") -class Daf( +class DistributedAggregation( Generic[ Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare, AggResult ], metaclass=ABCMeta): """ + Abstract base class containing methods common to DAFs and VDAFs. + """ + + @abstractmethod + def shard(self, + ctx: bytes, + measurement: Measurement, + nonce: bytes, + rand: bytes, + ) -> tuple[PublicShare, list[InputShare]]: + pass + + @abstractmethod + def is_valid(self, + agg_param: AggParam, + previous_agg_params: list[AggParam]) -> bool: + pass + + @abstractmethod + def agg_init(self, agg_param: AggParam) -> AggShare: + pass + + @abstractmethod + def agg_update(self, + agg_param: AggParam, + agg_share: AggShare, + out_share: OutShare) -> AggShare: + pass + + @abstractmethod + def merge(self, + agg_param: AggParam, + agg_shares: list[AggShare]) -> AggShare: + pass + + @abstractmethod + def unshard(self, + agg_param: AggParam, + agg_shares: list[AggShare], + num_measurements: int) -> AggResult: + pass + + +class Daf( + Generic[ + Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare, + AggResult + ], + DistributedAggregation[ + Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare, + AggResult + ]): + """ A Distributed Aggregation Function (DAF). Generic type parameters: @@ -51,6 +104,7 @@ class Daf( # Number of random bytes consumed by `shard()`. RAND_SIZE: int + @override @abstractmethod def shard( self, @@ -69,6 +123,7 @@ def shard( """ pass + @override @abstractmethod def is_valid( self, @@ -104,6 +159,7 @@ def prep( """ pass + @override @abstractmethod def agg_init(self, agg_param: AggParam) -> AggShare: @@ -112,6 +168,7 @@ def agg_init(self, """ pass + @override @abstractmethod def agg_update(self, agg_param: AggParam, @@ -123,6 +180,7 @@ def agg_update(self, """ pass + @override @abstractmethod def merge(self, agg_param: AggParam, @@ -132,6 +190,7 @@ def merge(self, """ pass + @override @abstractmethod def unshard( self, diff --git a/poc/vdaf_poc/vdaf.py b/poc/vdaf_poc/vdaf.py index 4f841dd2..f2d26247 100644 --- a/poc/vdaf_poc/vdaf.py +++ b/poc/vdaf_poc/vdaf.py @@ -1,9 +1,10 @@ """Definition of VDAFs.""" -from abc import ABCMeta, abstractmethod -from typing import Any, Generic, TypeVar +from abc import abstractmethod +from typing import Any, Generic, TypeVar, override from vdaf_poc.common import format_dst, gen_rand +from vdaf_poc.daf import DistributedAggregation from vdaf_poc.field import Field Measurement = TypeVar("Measurement") @@ -24,7 +25,10 @@ class Vdaf( Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare, AggResult, PrepState, PrepShare, PrepMessage ], - metaclass=ABCMeta): + DistributedAggregation[ + Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare, + AggResult + ]): """ A Verifiable Distributed Aggregation Function (VDAF). @@ -65,6 +69,7 @@ class Vdaf( # Name of the VDAF, for use in test vector filenames. test_vec_name: str + @override @abstractmethod def shard(self, ctx: bytes, @@ -83,6 +88,7 @@ def shard(self, """ pass + @override @abstractmethod def is_valid(self, agg_param: AggParam, previous_agg_params: list[AggParam]) -> bool: @@ -142,6 +148,7 @@ def prep_shares_to_prep(self, """ pass + @override @abstractmethod def agg_init(self, agg_param: AggParam) -> AggShare: @@ -150,6 +157,7 @@ def agg_init(self, """ pass + @override @abstractmethod def agg_update(self, agg_param: AggParam, @@ -161,6 +169,7 @@ def agg_update(self, """ pass + @override @abstractmethod def merge(self, agg_param: AggParam, @@ -170,6 +179,7 @@ def merge(self, """ pass + @override @abstractmethod def unshard(self, agg_param: AggParam,