generated from martinthomson/internet-draft-template
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
107 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
''' | ||
Demonstration of FLP functionality over Shamir secret shares. | ||
''' | ||
|
||
from typing import TypeVar | ||
from unittest import TestCase | ||
|
||
from vdaf_poc.common import vec_add | ||
from vdaf_poc.field import Field, Field64, poly_interp | ||
from vdaf_poc.flp_bbcggi19 import Sum | ||
|
||
F = TypeVar("F", bound=Field) | ||
|
||
def shamir_shard(field: type[F], | ||
meas: list[F], | ||
unshard_threshold: int, | ||
num_shares: int) -> list[list[F]]: | ||
''' | ||
Compute Shamir's threshold secret sharing of `meas`. `num_shares` specifies | ||
the total of number of shares; `unshard_threshold` specifies the number of | ||
shares needed to unshard. | ||
''' | ||
# NOTE It would be more efficient to do "packed" Shamir by letting f(x) = | ||
# meas + rand_1 * x + rand_2 * x^2 + ..., and so on and letting the shares | ||
# be f(0), f(1), ..., and so on. That is: | ||
# | ||
# coeffs = [meas] | ||
# for _ in range(unshard_threshold): | ||
# coeffs.append(field.rand_vec(len(meas))) | ||
# | ||
# meas_shares = [] | ||
# for x in range(num_shares): | ||
# f = field.zeros(len(meas)) | ||
# for (i, c) in enumerate(coeffs): | ||
# for j in range(len(meas)): | ||
# f[j] += c[j] * field(x)**i | ||
# meas_shares.append(f) | ||
# | ||
# However we would need a version of `poly_interp()` that operates on | ||
# vectors over the field. | ||
meas_shares = [field.zeros(len(meas)) for _ in range(num_shares)] | ||
for i in range(len(meas)): | ||
coeffs = [meas[i]] + field.rand_vec(unshard_threshold-1) | ||
for j in range(num_shares): | ||
f = field(0) | ||
for (k, c) in enumerate(coeffs): | ||
x = field(j+1) | ||
f += c * x ** k | ||
meas_shares[j][i] = f | ||
return meas_shares | ||
|
||
def shamir_unshard(field: type[F], | ||
index: list[int], | ||
meas_shares: list[list[F]]) -> list[F]: | ||
''' | ||
Combine Shamir secret shares `meas_share` into the underlying secret. | ||
`index` is indicates the index of each secret share. That is, `index[I]` is | ||
equal to the index of `meas_share[I]` in the output of `shamir_shard()`. | ||
''' | ||
meas_len = len(meas_shares[0]) | ||
meas = [] | ||
for i in range(meas_len): | ||
xs = [] | ||
fs = [] | ||
for (j, share) in zip(index, meas_shares): | ||
x = field(j+1) | ||
f = share[i] | ||
xs.append(x) | ||
fs.append(f) | ||
coeffs = poly_interp(field, xs, fs) | ||
meas.append(coeffs[0]) | ||
return meas | ||
|
||
|
||
class TestShamir(TestCase): | ||
def test_shard_unshard(self) -> None: | ||
'''Test the basic functionality.''' | ||
meas = [Field64(13), Field64(37)] | ||
meas_shares = shamir_shard(Field64, meas, 2, 3) | ||
got = shamir_unshard(Field64, [0, 1], [meas_shares[0], meas_shares[1]]) | ||
self.assertEqual(got, meas) | ||
got = shamir_unshard(Field64, [0, 2], [meas_shares[0], meas_shares[2]]) | ||
self.assertEqual(got, meas) | ||
got = shamir_unshard(Field64, [1, 2], [meas_shares[1], meas_shares[2]]) | ||
self.assertEqual(got, meas) | ||
|
||
def test_aggregate(self) -> None: | ||
'''Show that Shamir secret sharing works with aggregation as well.''' | ||
unshard_threshold = 2 | ||
num_shares = 3 | ||
max_measurement = 100 | ||
v = Sum(Field64, max_measurement) | ||
agg_shares = [v.field.zeros(v.OUTPUT_LEN)] * num_shares | ||
for x in range(max_measurement): | ||
meas = v.encode(x) | ||
meas_shares = shamir_shard(v.field, | ||
meas, | ||
unshard_threshold, | ||
num_shares) | ||
for j in range(num_shares): | ||
agg_shares[j] = vec_add(agg_shares[j], v.truncate(meas_shares[j])) | ||
agg = shamir_unshard(v.field, [0,1], [agg_shares[0], agg_shares[1]]) | ||
agg_result = v.decode(agg, max_measurement) | ||
self.assertEqual(agg_result, max_measurement * (max_measurement-1) / 2) | ||
|
||
|
||
|