Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a common base class to Daf and Vdaf #505

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions poc/vdaf_poc/daf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Definition of DAFs."""

from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, TypeVar, override

from vdaf_poc.common import gen_rand

Expand All @@ -14,13 +14,66 @@
AggResult = TypeVar("AggResult")


class Daf(
class DistributedAggregation(
Generic[
Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare,
AggResult
],
metaclass=ABCMeta):
"""
Abstract base class containing methods common to DAFs and VDAFs.
"""

@abstractmethod
def shard(self,
ctx: bytes,
measurement: Measurement,
nonce: bytes,
rand: bytes,
) -> tuple[PublicShare, list[InputShare]]:
pass

@abstractmethod
def is_valid(self,
agg_param: AggParam,
previous_agg_params: list[AggParam]) -> bool:
pass

@abstractmethod
def agg_init(self, agg_param: AggParam) -> AggShare:
pass

@abstractmethod
def agg_update(self,
agg_param: AggParam,
agg_share: AggShare,
out_share: OutShare) -> AggShare:
pass

@abstractmethod
def merge(self,
agg_param: AggParam,
agg_shares: list[AggShare]) -> AggShare:
pass

@abstractmethod
def unshard(self,
agg_param: AggParam,
agg_shares: list[AggShare],
num_measurements: int) -> AggResult:
pass


class Daf(
Generic[
Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare,
AggResult
],
DistributedAggregation[
Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare,
AggResult
]):
"""
A Distributed Aggregation Function (DAF).

Generic type parameters:
Expand Down Expand Up @@ -51,6 +104,7 @@ class Daf(
# Number of random bytes consumed by `shard()`.
RAND_SIZE: int

@override
cjpatton marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
def shard(
cjpatton marked this conversation as resolved.
Show resolved Hide resolved
self,
Expand All @@ -69,6 +123,7 @@ def shard(
"""
pass

@override
@abstractmethod
def is_valid(
self,
Expand Down Expand Up @@ -104,6 +159,7 @@ def prep(
"""
pass

@override
@abstractmethod
def agg_init(self,
agg_param: AggParam) -> AggShare:
Expand All @@ -112,6 +168,7 @@ def agg_init(self,
"""
pass

@override
@abstractmethod
def agg_update(self,
agg_param: AggParam,
Expand All @@ -123,6 +180,7 @@ def agg_update(self,
"""
pass

@override
@abstractmethod
def merge(self,
agg_param: AggParam,
Expand All @@ -132,6 +190,7 @@ def merge(self,
"""
pass

@override
@abstractmethod
def unshard(
self,
Expand Down
16 changes: 13 additions & 3 deletions poc/vdaf_poc/vdaf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Definition of VDAFs."""

from abc import ABCMeta, abstractmethod
from typing import Any, Generic, TypeVar
from abc import abstractmethod
from typing import Any, Generic, TypeVar, override

from vdaf_poc.common import format_dst, gen_rand
from vdaf_poc.daf import DistributedAggregation
from vdaf_poc.field import Field

Measurement = TypeVar("Measurement")
Expand All @@ -24,7 +25,10 @@ class Vdaf(
Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare,
AggResult, PrepState, PrepShare, PrepMessage
],
metaclass=ABCMeta):
DistributedAggregation[
Measurement, AggParam, PublicShare, InputShare, OutShare, AggShare,
AggResult
]):
"""
A Verifiable Distributed Aggregation Function (VDAF).

Expand Down Expand Up @@ -65,6 +69,7 @@ class Vdaf(
# Name of the VDAF, for use in test vector filenames.
test_vec_name: str

@override
@abstractmethod
def shard(self,
ctx: bytes,
Expand All @@ -83,6 +88,7 @@ def shard(self,
"""
pass

@override
@abstractmethod
def is_valid(self, agg_param: AggParam,
previous_agg_params: list[AggParam]) -> bool:
Expand Down Expand Up @@ -142,6 +148,7 @@ def prep_shares_to_prep(self,
"""
pass

@override
@abstractmethod
def agg_init(self,
agg_param: AggParam) -> AggShare:
Expand All @@ -150,6 +157,7 @@ def agg_init(self,
"""
pass

@override
@abstractmethod
def agg_update(self,
agg_param: AggParam,
Expand All @@ -161,6 +169,7 @@ def agg_update(self,
"""
pass

@override
@abstractmethod
def merge(self,
agg_param: AggParam,
Expand All @@ -170,6 +179,7 @@ def merge(self,
"""
pass

@override
@abstractmethod
def unshard(self,
agg_param: AggParam,
Expand Down