From 4710198818deff6753465b3f3e9539ba39c30f15 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Thu, 31 Oct 2024 12:42:57 -0700 Subject: [PATCH] WIP --- poc/example_flp_shamir.py | 107 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 poc/example_flp_shamir.py diff --git a/poc/example_flp_shamir.py b/poc/example_flp_shamir.py new file mode 100644 index 00000000..e4a8c825 --- /dev/null +++ b/poc/example_flp_shamir.py @@ -0,0 +1,107 @@ +''' +Demonstration of FLP functionality over Shamir secret shares. +''' + +from typing import TypeVar +from unittest import TestCase + +from vdaf_poc.common import vec_add +from vdaf_poc.field import Field, Field64, poly_interp +from vdaf_poc.flp_bbcggi19 import Sum + +F = TypeVar("F", bound=Field) + +def shamir_shard(field: type[F], + meas: list[F], + unshard_threshold: int, + num_shares: int) -> list[list[F]]: + ''' + Compute Shamir's threshold secret sharing of `meas`. `num_shares` specifies + the total of number of shares; `unshard_threshold` specifies the number of + shares needed to unshard. + ''' + # NOTE It would be more efficient to do "packed" Shamir by letting f(x) = + # meas + rand_1 * x + rand_2 * x^2 + ..., and so on and letting the shares + # be f(0), f(1), ..., and so on. That is: + # + # coeffs = [meas] + # for _ in range(unshard_threshold): + # coeffs.append(field.rand_vec(len(meas))) + # + # meas_shares = [] + # for x in range(num_shares): + # f = field.zeros(len(meas)) + # for (i, c) in enumerate(coeffs): + # for j in range(len(meas)): + # f[j] += c[j] * field(x)**i + # meas_shares.append(f) + # + # However we would need a version of `poly_interp()` that operates on + # vectors over the field. + meas_shares = [field.zeros(len(meas)) for _ in range(num_shares)] + for i in range(len(meas)): + coeffs = [meas[i]] + field.rand_vec(unshard_threshold-1) + for j in range(num_shares): + f = field(0) + for (k, c) in enumerate(coeffs): + x = field(j+1) + f += c * x ** k + meas_shares[j][i] = f + return meas_shares + +def shamir_unshard(field: type[F], + index: list[int], + meas_shares: list[list[F]]) -> list[F]: + ''' + Combine Shamir secret shares `meas_share` into the underlying secret. + `index` is indicates the index of each secret share. That is, `index[I]` is + equal to the index of `meas_share[I]` in the output of `shamir_shard()`. + ''' + meas_len = len(meas_shares[0]) + meas = [] + for i in range(meas_len): + xs = [] + fs = [] + for (j, share) in zip(index, meas_shares): + x = field(j+1) + f = share[i] + xs.append(x) + fs.append(f) + coeffs = poly_interp(field, xs, fs) + meas.append(coeffs[0]) + return meas + + +class TestShamir(TestCase): + def test_shard_unshard(self) -> None: + '''Test the basic functionality.''' + meas = [Field64(13), Field64(37)] + meas_shares = shamir_shard(Field64, meas, 2, 3) + got = shamir_unshard(Field64, [0, 1], [meas_shares[0], meas_shares[1]]) + self.assertEqual(got, meas) + got = shamir_unshard(Field64, [0, 2], [meas_shares[0], meas_shares[2]]) + self.assertEqual(got, meas) + got = shamir_unshard(Field64, [1, 2], [meas_shares[1], meas_shares[2]]) + self.assertEqual(got, meas) + + def test_aggregate(self) -> None: + '''Show that Shamir secret sharing works with aggregation as well.''' + unshard_threshold = 2 + num_shares = 3 + max_measurement = 100 + v = Sum(Field64, max_measurement) + agg_shares = [v.field.zeros(v.OUTPUT_LEN)] * num_shares + for x in range(max_measurement): + meas = v.encode(x) + meas_shares = shamir_shard(v.field, + meas, + unshard_threshold, + num_shares) + for j in range(num_shares): + agg_shares[j] = vec_add(agg_shares[j], v.truncate(meas_shares[j])) + agg = shamir_unshard(v.field, [0,1], [agg_shares[0], agg_shares[1]]) + agg_result = v.decode(agg, max_measurement) + self.assertEqual(agg_result, max_measurement * (max_measurement-1) / 2) + + +