Skip to content

Commit

Permalink
poc: Reduce test runtime by making TurboSHAKE128 stateful
Browse files Browse the repository at this point in the history
The reference implementation of TurboSHAKE128 provides a one-shot API.
To use this efficiently, we restrict the number of bytes required for
tests so that we can pre-compute the entire output buffer. This is
controlled by MAX_XOF_OUT_STREAM_BYTES.

This restriction is confusing and has side effects for specs that use
XofTurboShake128. However, using the one-shot API without this
optimization makes the unit tests prohibitively slow: about a minute on
my machine.

Implement a stateful API for TurboSHAKE128 and use it in
XofTurboShake128. This reduces the runtime to 20 seconds on my machine.

Accordingly, restore the unit tests for Poplar1 to what they were before
we made this optimization.
  • Loading branch information
cjpatton committed Jan 10, 2024
1 parent 3d2dd77 commit 97fcdd4
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 36 deletions.
1 change: 1 addition & 0 deletions poc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ test:
sage -python common.py
sage -python field.py
sage -python xof.py
sage -python turboshake.py
sage -python flp.py
sage -python flp_generic.py
sage -python idpf.py
Expand Down
189 changes: 189 additions & 0 deletions poc/turboshake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# A stateful implementation of TurboSHAKE adapted from the reference implementation
#
# We use TurboSHAKE in two steps:
#
# 1. Message fragments are absorbed into the hash state
# 2. Output fragments are squeezed out of the hash state
#
# The reference implementation of TurboSHAKE only provides a "one-shot" API,
# where the message and the length of the output are determined in advance.
#
# The stateful API is not needed if you know the desired output length in
# advance. Even if you don't know the desired output length, you can always do
# something like this:
#
# 1. Concatenate the message fragments into message `M`
# 2. Keep track of the output length `totalOutputBytesLen` squeezed so far and
# output `TurboSHAKE(c, M, D, totalOutputBytesLen+nextOutputBytesLen)`.
#
# However, if the output length is large, then this is prohibitively slow, even
# for reference code, because every call recomputes the entire output stream
# from the beginning. In particular, this makes the unit tests for Prio3 and
# Poplar1 take well over 30 seconds to run. Thus the purpose of implementing a
# stateful API is to make our unit tests run in a reasonable amount of time.

import os
import sys

kangarootwelve_path = \
"%s/draft-irtf-cfrg-kangarootwelve/py" % os.path.dirname(__file__) # nopep8
assert os.path.isdir(kangarootwelve_path) # nopep8
sys.path.append(kangarootwelve_path) # nopep8

from TurboSHAKE import KeccakP1600, TurboSHAKE128


class TurboSHAKEAbosrb:
    '''
    TurboSHAKE in the absorb state.

    Message fragments are XORed into the first `rate_in_bytes` bytes of a
    200-byte state by `update()`. `squeeze()` applies the padding and
    returns the squeeze state.

    NOTE: the class name is misspelled ("Abosrb"). It is kept because
    existing code refers to it; `TurboSHAKEAbsorb`, defined right after
    the class, is an alias with the intended spelling.
    '''

    def __init__(self, c, D):
        '''
        Initialize the absorb state with capacity `c` (number of bits) and
        domain separation byte `D`.

        Raises `ValueError` if `c` does not leave a rate of 1 to 200 bytes,
        or if `D` does not fit in a single byte.
        '''
        rate_in_bytes = (1600-c)//8
        # A non-positive rate would make `update()` spin forever (no input
        # byte is ever consumed), and a rate above 200 would index past the
        # end of the 200-byte state.
        if not 0 < rate_in_bytes <= 200:
            raise ValueError(
                'capacity {} leaves an invalid rate of {} bytes'.format(
                    c, rate_in_bytes))
        # `D` is XORed into a single state byte by `squeeze()`; reject
        # values that cannot fit now rather than failing there.
        if not 0 <= D <= 0xff:
            raise ValueError(
                'domain separation byte {} is out of range'.format(D))

        self.D = D
        self.rate_in_bytes = rate_in_bytes
        self.state = bytearray(200)
        # Position within the current rate block at which the next message
        # byte is absorbed. Always less than `rate_in_bytes`.
        self.state_offset = 0

    def update(self, M):
        '''
        Update the absorb state with message fragment `M`.

        `M` may have any length, including zero. Fragment boundaries are
        not recorded: the state only tracks the position in the current
        rate block.
        '''
        input_offset = 0
        while input_offset < len(M):
            # Take as much of the fragment as fits in the current block.
            block_size = min(len(M)-input_offset,
                             self.rate_in_bytes-self.state_offset)
            block = M[input_offset:input_offset+block_size]
            for (i, byte) in enumerate(block, self.state_offset):
                self.state[i] ^= byte
            input_offset += block_size
            self.state_offset += block_size
            if self.state_offset == self.rate_in_bytes:
                # The block is full: permute and start a new one.
                self.state = KeccakP1600(self.state, 12)
                self.state_offset = 0

    def squeeze(self):
        '''
        Apply the padding and return the TurboSHAKE squeeze state.

        The padding is applied to a copy of the state, so this object is
        left unmodified.
        '''
        state = bytearray(self.state)  # copy; leave `self.state` untouched
        state[self.state_offset] ^= self.D
        # If the high bit of `D` lands on the last byte of the rate, the
        # final padding bit below needs a block of its own.
        if (self.D & 0x80) != 0 and \
                self.state_offset == self.rate_in_bytes-1:
            state = KeccakP1600(state, 12)
        state[self.rate_in_bytes-1] ^= 0x80
        state = KeccakP1600(state, 12)

        # `TurboSHAKESqueeze` has no constructor; its fields are assigned
        # here.
        squeeze = TurboSHAKESqueeze()
        squeeze.rate_in_bytes = self.rate_in_bytes
        squeeze.state = state
        squeeze.state_offset = 0
        return squeeze


# Correctly spelled alias for the class above.
TurboSHAKEAbsorb = TurboSHAKEAbosrb


class TurboSHAKESqueeze:
    '''
    TurboSHAKE in the squeeze state.

    The class defines no constructor. `next()` reads the attributes
    `rate_in_bytes`, `state` and `state_offset`, which must be assigned by
    whoever creates the instance.
    '''

    def next(self, length):
        '''
        Return the next `length` bytes of output and update the squeeze
        state.

        The result is a `bytearray`. It is empty if `length` is not
        positive.
        '''
        chunks = []
        remaining = length
        while remaining > 0:
            start = self.state_offset
            # Never read past the end of the current rate block.
            take = min(remaining, self.rate_in_bytes-start)
            stop = start+take
            remaining -= take
            chunks.append(self.state[start:stop])
            self.state_offset = stop
            if stop == self.rate_in_bytes:
                # Block exhausted: permute to produce the next one.
                self.state = KeccakP1600(self.state, 12)
                self.state_offset = 0
        return bytearray().join(chunks)


def NewTurboSHAKE128(D):
    '''
    Create a fresh TurboSHAKE128 absorb state.

    `D` is the domain separation byte. The capacity passed to the absorb
    state is 256 bits.
    '''
    capacity_bits = 256
    return TurboSHAKEAbosrb(capacity_bits, D)


def testAPI(stateful, oneshot):
'''Test that the outputs of the stateful and oneshot APIs match.'''

test_cases = [
{
'fragments': [],
'lengths': [],
},
{
'fragments': [],
'lengths': [
1000,
],
},
{
'fragments': [
b'\xff' * 500,
],
'lengths': [
12,
],
},
{
'fragments': [
b'hello',
b', ',
b'',
b'world',
],
'lengths': [
1,
17,
256,
128,
0,
7,
14,
],
},
{
'fragments': [
b'\xff' * 1024,
b'\x17' * 23,
b'',
b'\xf1' * 512,
],
'lengths': [
1000,
0,
0,
14,
],

}
]

D = 99
for (i, test_case) in enumerate(test_cases):
absorb = stateful(D)
message = bytearray()
for fragment in test_case['fragments']:
absorb.update(fragment)
message += fragment
squeeze = absorb.squeeze()
output = b''
output_len = 0
for length in test_case['lengths']:
output += squeeze.next(length)
output_len += length
expected_output = oneshot(message, D, output_len)
if output != expected_output:
raise Exception('test case {} failed: got {}; want {}'.format(
i,
output.hex(),
expected_output.hex(),
))


# When run as a script, check the stateful TurboSHAKE128 implementation
# against the one-shot reference function imported from `TurboSHAKE`.
if __name__ == '__main__':
    testAPI(NewTurboSHAKE128, TurboSHAKE128)
20 changes: 10 additions & 10 deletions poc/vdaf_poplar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,28 +414,28 @@ def encode_idpf_field_vec(vec):
[2],
)
test_vdaf(
Poplar1.with_bits(64),
Poplar1.with_bits(128),
(
63,
(from_be_bytes(b'01234567'),),
127,
(from_be_bytes(b'0123456789abcdef'),),
),
[
from_be_bytes(b'01234567'),
from_be_bytes(b'0123456789abcdef'),
],
[1],
)
test_vdaf(
Poplar1.with_bits(64),
Poplar1.with_bits(256),
(
31,
63,
(
from_be_bytes(b'0000'),
from_be_bytes(b'0123'),
from_be_bytes(b'00000000'),
from_be_bytes(b'01234567'),
),
),
[
from_be_bytes(b'01234567'),
from_be_bytes(b'01234000'),
from_be_bytes(b'0123456789abcdef0123456789abcdef'),
from_be_bytes(b'01234567890000000000000000000000'),
],
[0, 2],
)
Expand Down
54 changes: 28 additions & 26 deletions poc/xof.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,12 @@

from __future__ import annotations

import os
import sys

from Cryptodome.Cipher import AES

kangarootwelve_path = \
"%s/draft-irtf-cfrg-kangarootwelve/py" % os.path.dirname(__file__) # nopep8
assert os.path.isdir(kangarootwelve_path) # nopep8
sys.path.append(kangarootwelve_path) # nopep8
from TurboSHAKE import TurboSHAKE128

from common import (TEST_VECTOR, VERSION, Bytes, Unsigned, concat, format_dst,
from_le_bytes, gen_rand, next_power_of_2,
print_wrapped_line, to_le_bytes, xor)

# Maximum XOF output length that will be requested by any test in this package.
# Each time `XofTurboShake128` is constructed we call `TurboSHAKE128()` once
# and fill a buffer with the output stream.
MAX_XOF_OUT_STREAM_BYTES = 2000
from turboshake import NewTurboSHAKE128, TurboSHAKE128


class Xof:
Expand Down Expand Up @@ -84,18 +71,33 @@ class XofTurboShake128(Xof):
test_vec_name = 'XofTurboShake128'

def __init__(self, seed, dst, binder):
'''
self.l = 0
self.m = to_le_bytes(len(dst), 1) + dst + seed + binder
'''
self.length_consumed = 0
self.stream = TurboSHAKE128(
to_le_bytes(len(dst), 1) + dst + seed + binder,
1,
MAX_XOF_OUT_STREAM_BYTES,
)
state = NewTurboSHAKE128(1)
state.update(to_le_bytes(len(dst), 1))
state.update(dst)
state.update(seed)
state.update(binder)
self.state = state.squeeze()

def next(self, length):
assert self.length_consumed + length < MAX_XOF_OUT_STREAM_BYTES
out = self.stream[self.length_consumed:self.length_consumed+length]
self.length_consumed += length
return out
'''
self.l += length
# Function `TurboSHAKE128(M, D, L)` is as defined in
# Section 2.2 of [TurboSHAKE].
#
# Implementation note: Rather than re-generate the output
# stream each time `next()` is invoked, most implementations
# of TurboSHAKE128 will expose an "absorb-then-squeeze" API that
# allows stateful handling of the stream.
stream = TurboSHAKE128(self.m, 1, self.l)
return stream[-length:]
'''
return self.state.next(length)


class XofFixedKeyAes128(Xof):
Expand All @@ -113,9 +115,9 @@ class XofFixedKeyAes128(Xof):
def __init__(self, seed, dst, binder):
self.length_consumed = 0

# Use SHA-3 to derive a key from the binder string and domain
# separation tag. Note that the AES key does not need to be
# kept secret from any party. However, when used with
# Use TurboSHAKE128 to derive a key from the binder string and
# domain separation tag. Note that the AES key does not need
# to be kept secret from any party. However, when used with
# IdpfPoplar, we require the binder to be a random nonce.
#
# Implementation note: This step can be cached across XOF
Expand Down

0 comments on commit 97fcdd4

Please sign in to comment.