Skip to content

Commit

Permalink
IDPF: Simplify code, add constant-time notes (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Sep 11, 2024
1 parent 6e7b1c1 commit 6c1e320
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 66 deletions.
62 changes: 39 additions & 23 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -5061,8 +5061,6 @@ def gen(
ctrl = [Field2(0), Field2(1)]
public_share = []
for level in range(self.BITS):
field: type[Field]
field = cast(type[Field], self.current_field(level))
keep = (alpha >> (self.BITS - level - 1)) & 1
lose = 1 - keep
bit = Field2(keep)
Expand All @@ -5075,12 +5073,24 @@ def gen(
t0[1] + t1[1] + bit,
)

x0 = xor(s0[keep], ctrl[0].conditional_select(seed_cw))
x1 = xor(s1[keep], ctrl[1].conditional_select(seed_cw))
# Implementation note: these conditional XORs and
# input-dependent array indices should be replaced with
# constant-time selects in practice in order to reduce
# leakage via timing side channels.
if ctrl[0].as_unsigned():
x0 = xor(s0[keep], seed_cw)
ctrl[0] = t0[keep] + ctrl_cw[keep]
else:
x0 = s0[keep]
ctrl[0] = t0[keep]
if ctrl[1].as_unsigned():
x1 = xor(s1[keep], seed_cw)
ctrl[1] = t1[keep] + ctrl_cw[keep]
else:
x1 = s1[keep]
ctrl[1] = t1[keep]
(seed[0], w0) = self.convert(level, x0, nonce)
(seed[1], w1) = self.convert(level, x1, nonce)
ctrl[0] = t0[keep] + ctrl[0] * ctrl_cw[keep]
ctrl[1] = t1[keep] + ctrl[1] * ctrl_cw[keep]

if level < self.BITS - 1:
b = cast(list[Field], beta_inner[level])
Expand All @@ -5092,12 +5102,13 @@ def gen(
)

w_cw = vec_add(vec_sub(b, w0), w1)
# Implementation note: here we negate the correction word if
# the control bit `ctrl[1]` is set. We avoid branching on the
# value in order to reduce leakage via timing side channels.
mask = field(1) - field(2) * field(ctrl[1].as_unsigned())
for i in range(len(w_cw)):
w_cw[i] *= mask
# Implementation note: this conditional negation should be
# replaced with a constant time select or a constant time
# multiplication in practice in order to reduce leakage via
# timing side channels.
if ctrl[1].as_unsigned():
for i in range(len(w_cw)):
w_cw[i] = -w_cw[i]

public_share.append((seed_cw, ctrl_cw, w_cw))
return (public_share, key)
Expand Down Expand Up @@ -5189,26 +5200,31 @@ def eval_next(
bit of the prefix corresponding to the next level of the tree.
"""

field = self.current_field(level)
seed_cw = correction_word[0]
ctrl_cw = correction_word[1]
w_cw = cast(list[Field], correction_word[2])
(s, t) = self.extend(level, prev_seed, nonce)
s[0] = xor(s[0], prev_ctrl.conditional_select(seed_cw))
s[1] = xor(s[1], prev_ctrl.conditional_select(seed_cw))
t[0] += ctrl_cw[0] * prev_ctrl
t[1] += ctrl_cw[1] * prev_ctrl

# Implementation note: these conditional operations and
# input-dependent array indices should be replaced with
# constant-time selects in practice in order to reduce leakage
# via timing side channels.
if prev_ctrl.as_unsigned():
s[0] = xor(s[0], seed_cw)
s[1] = xor(s[1], seed_cw)
t[0] += ctrl_cw[0]
t[1] += ctrl_cw[1]

next_ctrl = t[bit]
convert_output = self.convert(level, s[bit], nonce)
next_seed = convert_output[0]
y = cast(list[Field], convert_output[1])
# Implementation note: here we add the correction word to the
# output if `next_ctrl` is set. We avoid branching on the value
# of the control bit in order to reduce side channel leakage.
mask = cast(Field, field(next_ctrl.as_unsigned()))
for i in range(len(y)):
y[i] += w_cw[i] * mask
# Implementation note: this conditional addition should be
# replaced with a constant-time select in practice in order to
# reduce leakage via timing side channels.
if next_ctrl.as_unsigned():
for i in range(len(y)):
y[i] += w_cw[i]

return (next_seed, next_ctrl, cast(FieldVec, y))
~~~
Expand Down
3 changes: 0 additions & 3 deletions poc/tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def test_field2(self) -> None:
self.assertEqual(Field2(1) + Field2(1), Field2(0))
self.assertEqual(Field2(1) * Field2(1), Field2(1))
self.assertEqual(-Field2(1), Field2(1))
self.assertEqual(Field2(1).conditional_select(b'hello'), b'hello')
self.assertEqual(Field2(0).conditional_select(
b'hello'), bytes([0, 0, 0, 0, 0]))

def test_interp(self) -> None:
# Test polynomial interpolation.
Expand Down
17 changes: 0 additions & 17 deletions poc/vdaf_poc/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,6 @@ class Field2(Field):
# Sage finite field object.
gf = GF(MODULUS)

def conditional_select(self, inp: bytes) -> bytes:
"""
Return `inp` unmodified if `self == 1`; otherwise return the all-zero
string of the same length.
Implementation note: to protect the code from timing side channels, it
is important to implement this algorithm in constant time.
"""

# Convert the element into a bitmask such that `m == 255` if
# `self == 1` and `m == 0` otherwise.
m = 0
v = self.as_unsigned()
for i in range(8):
m |= v << i
return bytes(map(lambda x: m & x, inp))


class Field64(NttField):
"""The finite field GF(2^32 * 4294967295 + 1)."""
Expand Down
62 changes: 39 additions & 23 deletions poc/vdaf_poc/idpf_bbcggi21.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ def gen(
ctrl = [Field2(0), Field2(1)]
public_share = []
for level in range(self.BITS):
field: type[Field]
field = cast(type[Field], self.current_field(level))
keep = (alpha >> (self.BITS - level - 1)) & 1
lose = 1 - keep
bit = Field2(keep)
Expand All @@ -105,12 +103,24 @@ def gen(
t0[1] + t1[1] + bit,
)

x0 = xor(s0[keep], ctrl[0].conditional_select(seed_cw))
x1 = xor(s1[keep], ctrl[1].conditional_select(seed_cw))
# Implementation note: these conditional XORs and
# input-dependent array indices should be replaced with
# constant-time selects in practice in order to reduce
# leakage via timing side channels.
if ctrl[0].as_unsigned():
x0 = xor(s0[keep], seed_cw)
ctrl[0] = t0[keep] + ctrl_cw[keep]
else:
x0 = s0[keep]
ctrl[0] = t0[keep]
if ctrl[1].as_unsigned():
x1 = xor(s1[keep], seed_cw)
ctrl[1] = t1[keep] + ctrl_cw[keep]
else:
x1 = s1[keep]
ctrl[1] = t1[keep]
(seed[0], w0) = self.convert(level, x0, nonce)
(seed[1], w1) = self.convert(level, x1, nonce)
ctrl[0] = t0[keep] + ctrl[0] * ctrl_cw[keep]
ctrl[1] = t1[keep] + ctrl[1] * ctrl_cw[keep]

if level < self.BITS - 1:
b = cast(list[Field], beta_inner[level])
Expand All @@ -122,12 +132,13 @@ def gen(
)

w_cw = vec_add(vec_sub(b, w0), w1)
# Implementation note: here we negate the correction word if
# the control bit `ctrl[1]` is set. We avoid branching on the
# value in order to reduce leakage via timing side channels.
mask = field(1) - field(2) * field(ctrl[1].as_unsigned())
for i in range(len(w_cw)):
w_cw[i] *= mask
# Implementation note: this conditional negation should be
# replaced with a constant time select or a constant time
# multiplication in practice in order to reduce leakage via
# timing side channels.
if ctrl[1].as_unsigned():
for i in range(len(w_cw)):
w_cw[i] = -w_cw[i]

public_share.append((seed_cw, ctrl_cw, w_cw))
return (public_share, key)
Expand Down Expand Up @@ -214,26 +225,31 @@ def eval_next(
bit of the prefix corresponding to the next level of the tree.
"""

field = self.current_field(level)
seed_cw = correction_word[0]
ctrl_cw = correction_word[1]
w_cw = cast(list[Field], correction_word[2])
(s, t) = self.extend(level, prev_seed, nonce)
s[0] = xor(s[0], prev_ctrl.conditional_select(seed_cw))
s[1] = xor(s[1], prev_ctrl.conditional_select(seed_cw))
t[0] += ctrl_cw[0] * prev_ctrl
t[1] += ctrl_cw[1] * prev_ctrl

# Implementation note: these conditional operations and
# input-dependent array indices should be replaced with
# constant-time selects in practice in order to reduce leakage
# via timing side channels.
if prev_ctrl.as_unsigned():
s[0] = xor(s[0], seed_cw)
s[1] = xor(s[1], seed_cw)
t[0] += ctrl_cw[0]
t[1] += ctrl_cw[1]

next_ctrl = t[bit]
convert_output = self.convert(level, s[bit], nonce)
next_seed = convert_output[0]
y = cast(list[Field], convert_output[1])
# Implementation note: here we add the correction word to the
# output if `next_ctrl` is set. We avoid branching on the value
# of the control bit in order to reduce side channel leakage.
mask = cast(Field, field(next_ctrl.as_unsigned()))
for i in range(len(y)):
y[i] += w_cw[i] * mask
# Implementation note: this conditional addition should be
# replaced with a constant-time select in practice in order to
# reduce leakage via timing side channels.
if next_ctrl.as_unsigned():
for i in range(len(y)):
y[i] += w_cw[i]

return (next_seed, next_ctrl, cast(FieldVec, y))

Expand Down

0 comments on commit 6c1e320

Please sign in to comment.