Skip to content

Commit

Permalink
Replace SHAKE128 with TurboSHAKE128
Browse files Browse the repository at this point in the history
The reference code uses the reference implementation of TurboSHAKE128.
This code is unoptimized, so care is needed to ensure our tests run in a
reasonable amount of time.

Each time `XofTurboShake128` is constructed we call `TurboSHAKE128()`
once and fill a buffer with the output stream. The size of the buffer is
a constant, `MAX_XOF_OUT_STREAM_BYTES`, chosen to be sufficiently long
for every test that we have. So that we don't have to make this value
too large, some of tests in `vdaf_poplar1.py` have been modified.
  • Loading branch information
cjpatton committed Nov 9, 2023
1 parent ccbb6a0 commit 450314c
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 56 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "poc/draft-irtf-cfrg-kangarootwelve"]
path = poc/draft-irtf-cfrg-kangarootwelve
url = https://github.com/cfrg/draft-irtf-cfrg-kangarootwelve
12 changes: 6 additions & 6 deletions poc/daf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import field
from common import Bool, Unsigned, gen_rand
from xof import XofShake128
from xof import XofTurboShake128


class Daf:
Expand Down Expand Up @@ -166,11 +166,11 @@ class TestDaf(Daf):

@classmethod
def shard(cls, measurement, _nonce, rand):
helper_shares = XofShake128.expand_into_vec(cls.Field,
rand,
b'',
b'',
cls.SHARES-1)
helper_shares = XofTurboShake128.expand_into_vec(cls.Field,
rand,
b'',
b'',
cls.SHARES-1)
leader_share = cls.Field(measurement)
for helper_share in helper_shares:
leader_share -= helper_share
Expand Down
1 change: 1 addition & 0 deletions poc/draft-irtf-cfrg-kangarootwelve
22 changes: 11 additions & 11 deletions poc/vdaf_poplar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def with_bits(Poplar1, bits: Unsigned):
TheIdpf = idpf_poplar.IdpfPoplar \
.with_value_len(2) \
.with_bits(bits)
TheXof = xof.XofShake128
TheXof = xof.XofTurboShake128

class Poplar1WithBits(Poplar1):
Idpf = TheIdpf
Expand Down Expand Up @@ -414,28 +414,28 @@ def encode_idpf_field_vec(vec):
[2],
)
test_vdaf(
Poplar1.with_bits(128),
Poplar1.with_bits(64),
(
127,
(from_be_bytes(b'0123456789abcdef'),),
63,
(from_be_bytes(b'01234567'),),
),
[
from_be_bytes(b'0123456789abcdef'),
from_be_bytes(b'01234567'),
],
[1],
)
test_vdaf(
Poplar1.with_bits(256),
Poplar1.with_bits(64),
(
63,
31,
(
from_be_bytes(b'00000000'),
from_be_bytes(b'01234567'),
from_be_bytes(b'0000'),
from_be_bytes(b'0123'),
),
),
[
from_be_bytes(b'0123456789abcdef0123456789abcdef'),
from_be_bytes(b'01234567890000000000000000000000'),
from_be_bytes(b'01234567'),
from_be_bytes(b'01234000'),
],
[0, 2],
)
Expand Down
28 changes: 14 additions & 14 deletions poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,23 +443,23 @@ def test_vec_encode_prep_msg(Prio3, k_joint_rand):

class Prio3Count(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128
Xof = xof.XofTurboShake128
Flp = flp_generic.FlpGeneric(flp_generic.Count())

# Associated parameters.
ID = 0x00000000
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE

# Operational parameters.
test_vec_name = 'Prio3Count'


class Prio3Sum(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128
Xof = xof.XofTurboShake128

# Associated parameters.
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE
ID = 0x00000001

# Operational parameters.
Expand All @@ -474,10 +474,10 @@ class Prio3SumWithBits(Prio3Sum):

class Prio3SumVec(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128
Xof = xof.XofTurboShake128

# Associated parameters.
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE
ID = 0x00000002

# Operational parameters.
Expand All @@ -495,10 +495,10 @@ class Prio3SumVecWithParams(Prio3SumVec):

class Prio3Histogram(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128
Xof = xof.XofTurboShake128

# Associated parameters.
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE
ID = 0x00000003

# Operational parameters.
Expand Down Expand Up @@ -552,10 +552,10 @@ class Prio3SumVecWithMultiproofAndParams(cls):

class Prio3MultiHotHistogram(Prio3):
# Generic types required by `Prio3`
Xof = xof.XofShake128
Xof = xof.XofTurboShake128

# Associated parameters.
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE
# Private codepoint just for testing.
ID = 0xFFFFFFFF

Expand Down Expand Up @@ -584,11 +584,11 @@ class TestPrio3Average(Prio3):
class's decode() method.
"""

Xof = xof.XofShake128
Xof = xof.XofTurboShake128
# NOTE 0xFFFFFFFF is reserved for testing. If we decide to standardize this
# Prio3 variant, then we'll need to pick a real codepoint for it.
ID = 0xFFFFFFFF
VERIFY_KEY_SIZE = xof.XofShake128.SEED_SIZE
VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE

@classmethod
def with_bits(cls, bits: Unsigned):
Expand Down Expand Up @@ -650,7 +650,7 @@ def test_prio3sumvec_with_multiproof():
num_shares = 2 # Must be in range `[2, 256)`

cls = Prio3 \
.with_xof(xof.XofShake128) \
.with_xof(xof.XofTurboShake128) \
.with_flp(flp.FlpTestField128()) \
.with_shares(num_shares)
cls.ID = 0xFFFFFFFF
Expand All @@ -659,7 +659,7 @@ def test_prio3sumvec_with_multiproof():
# If JOINT_RAND_LEN == 0, then Fiat-Shamir isn't needed and we can skip
# generating the joint randomness.
cls = Prio3 \
.with_xof(xof.XofShake128) \
.with_xof(xof.XofTurboShake128) \
.with_flp(flp.FlpTestField128.with_joint_rand_len(0)) \
.with_shares(num_shares)
cls.ID = 0xFFFFFFFF
Expand Down
52 changes: 27 additions & 25 deletions poc/xof.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

from __future__ import annotations

import sys
sys.path.append('draft-irtf-cfrg-kangarootwelve/py')

from Cryptodome.Cipher import AES
from Cryptodome.Hash import SHAKE128
from TurboSHAKE import TurboSHAKE128

from common import (TEST_VECTOR, VERSION, Bytes, Unsigned, concat, format_dst,
from_le_bytes, gen_rand, next_power_of_2,
print_wrapped_line, to_le_bytes, xor)

# XXX Not a constant in the spec, need to use TurboSHAKE128 efficiently.
MAX_XOF_OUT_STREAM_BYTES = 2000

class Xof:
"""The base class for XOFs."""
Expand Down Expand Up @@ -61,27 +66,28 @@ def expand_into_vec(Xof,
return xof.next_vec(Field, length)


class XofShake128(Xof):
class XofTurboShake128(Xof):
"""XOF based on SHA-3 (SHAKE128)."""

# Associated parameters
SEED_SIZE = 16

# Operational parameters.
test_vec_name = 'XofShake128'
test_vec_name = 'XofTurboShake128'

def __init__(self, seed, dst, binder):
# The input is composed of `dst`, the domain separation tag, the
# `seed`, and the `binder` string.
self.shake = SHAKE128.new()
dst_length = to_le_bytes(len(dst), 1)
self.shake.update(dst_length)
self.shake.update(dst)
self.shake.update(seed)
self.shake.update(binder)
self.length_consumed = 0
self.out_stream = TurboSHAKE128(
to_le_bytes(len(dst), 1) + dst + seed + binder,
1,
MAX_XOF_OUT_STREAM_BYTES,
)

def next(self, length: Unsigned) -> Bytes:
return self.shake.read(length)
def next(self, length):
assert self.length_consumed + length < MAX_XOF_OUT_STREAM_BYTES
out = self.out_stream[self.length_consumed:self.length_consumed+length]
self.length_consumed += length
return out


class XofFixedKeyAes128(Xof):
Expand All @@ -106,12 +112,8 @@ def __init__(self, seed, dst, binder):
#
# Implementation note: This step can be cached across XOF
# evaluations with many different seeds.
shake = SHAKE128.new()
dst_length = to_le_bytes(len(dst), 1)
shake.update(dst_length)
shake.update(dst)
shake.update(binder)
fixed_key = shake.read(16)
fixed_key = TurboSHAKE128(
to_le_bytes(len(dst), 1) + dst + binder, 2, 16)
self.cipher = AES.new(fixed_key, AES.MODE_ECB)
# Save seed to be used in `next`.
self.seed = seed
Expand Down Expand Up @@ -182,17 +184,17 @@ def test_xof(Xof, F, expanded_len):

# This test case was found through brute-force search using this tool:
# https://github.com/divergentdave/vdaf-rejection-sampling-search
expanded_vec = XofShake128.expand_into_vec(
expanded_vec = XofTurboShake128.expand_into_vec(
Field64,
bytes([0x29, 0xb2, 0x98, 0x64, 0xb4, 0xaa, 0x4e, 0x07, 0x2a, 0x44,
0x49, 0x24, 0xf6, 0x74, 0x0a, 0x3d]),
bytes([0xd1, 0x95, 0xec, 0x90, 0xc1, 0xbc, 0xf1, 0xf2, 0xcb, 0x2c,
0x7e, 0x74, 0xc5, 0xc5, 0xf6, 0xda]),
b'', # domain separation tag
b'', # binder
33237,
140,
)
assert expanded_vec[-1] == Field64(2035552711764301796)
assert expanded_vec[-1] == Field64(9734340616212735019)

for cls in (XofShake128, XofFixedKeyAes128):
for cls in (XofTurboShake128, XofFixedKeyAes128):
test_xof(cls, Field128, 23)

if TEST_VECTOR:
Expand Down

0 comments on commit 450314c

Please sign in to comment.