Special thanks to Eli ben Sasson for his kind assistance, as usual. Special thanks to Chih-Cheng Liang and Justin Drake for review, and to Ben Fisch for suggesting the reverse MIMC technique for a VDF (paper here)
Trigger warning: math and lots of python
As a followup to Part 1 and Part 2 of this series, this post will cover what it looks like to actually implement a STARK, complete with an implementation in python. STARKs ("Scalable Transparent ARgument of Knowledge") are a technique for creating a proof that a computation was carried out correctly, where the proof can be verified much faster than re-running the computation.
First, some disclaimers:
- This code has not been thoroughly audited; soundness in production use cases is not guaranteed
- This code is very suboptimal (it's written in Python, what did you expect)
- STARKs "in real life" (ie. as implemented in Eli and co's production implementations) tend to use binary fields and not prime fields for application-specific efficiency reasons; however, they do stress in their writings that the prime field-based approach to STARKs described here is legitimate and can be used
- There is no "one true way" to do a STARK. It's a broad category of cryptographic and mathematical constructs, with different setups optimal for different applications and constant ongoing research to reduce prover and verifier complexity and improve soundness.
- This article absolutely expects you to know how modular arithmetic and prime fields work, and be comfortable with the concepts of polynomials, interpolation and evaluation. If you don't, go back to Part 2, and also this earlier post on quadratic arithmetic programs
Now, let's get to it.
Here is the function we'll be doing a STARK of:
import time

# `modulus` is the global prime field modulus described below
def mimc(inp, steps, round_constants):
    start_time = time.time()
    for i in range(steps-1):
        inp = (inp**3 + round_constants[i % len(round_constants)]) % modulus
    print("MIMC computed in %.4f sec" % (time.time() - start_time))
    return inp
We choose MIMC (see paper) as the example because it is both (i) simple to understand and (ii) interesting enough to be useful in real life. The function can be viewed visually as follows:
Note: in many discussions of MIMC, you will typically see XOR used instead of +; this is because MIMC is typically done over binary fields, where addition _is_ XOR; here we are doing it over prime fields.
In our example, the round constants will be a relatively small list (eg. 64 items) that gets cycled through over and over again (that is, after k[64] it loops back to using k[1]).
MIMC with a very large number of rounds, as we're doing here, is useful as a verifiable delay function - a function which is difficult to compute, and particularly non-parallelizable to compute, but relatively easy to verify. MIMC by itself achieves this property to some extent because MIMC can be computed "backward" (recovering the "input" from its corresponding "output"), but computing it backward takes about 100 times longer to compute than the forward direction (and neither direction can be significantly sped up by parallelization). So you can think of computing the function in the backward direction as being the act of "computing" the non-parallelizable proof of work, and computing the function in the forward direction as being the process of "verifying" it.
$x \rightarrow x^{(2p-1)/3}$ gives the inverse of $x \rightarrow x^3$; this is true because of Fermat's Little Theorem, a theorem that despite its supposed littleness is arguably much more important to mathematics than Fermat's more famous "Last Theorem".
What we will try to achieve here is to make verification much more efficient by using a STARK - instead of the verifier having to run MIMC in the forward direction themselves, the prover, after completing the computation in the "backward direction", would compute a STARK of the computation in the "forward direction", and the verifier would simply verify the STARK. The hope is that the overhead of computing a STARK can be less than the difference in speed running MIMC forwards relative to backwards, so a prover's time would still be dominated by the initial "backward" computation, and not the (highly parallelizable) STARK computation. Verification of a STARK can be relatively fast (in our python implementation, ~0.05-0.3 seconds), no matter how long the original computation is.
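To make the forward/backward asymmetry concrete, here is a minimal sketch of both directions. It assumes a toy prime of 101 and made-up round constants (neither is from the article's code); the only requirement it relies on is that the prime leaves remainder 2 when divided by 3, so that $x \rightarrow x^{(2p-1)/3}$ undoes cubing.

```python
# Toy parameters (assumptions for illustration; not the article's modulus).
P = 101                      # small prime with P % 3 == 2, so cubing is invertible
CONSTANTS = [7, 13, 42, 99]  # hypothetical round constants

def mimc_forward(x, steps, constants=CONSTANTS, p=P):
    """Forward direction: repeatedly cube and add a round constant."""
    for i in range(steps - 1):
        x = (x**3 + constants[i % len(constants)]) % p
    return x

def mimc_backward(y, steps, constants=CONSTANTS, p=P):
    """Backward direction: undo each round with a modular cube root."""
    cube_root_exponent = (2 * p - 1) // 3
    for i in reversed(range(steps - 1)):
        y = pow((y - constants[i % len(constants)]) % p, cube_root_exponent, p)
    return y

original = 5
output = mimc_forward(original, 32)
recovered = mimc_backward(output, 32)
print(original, output, recovered)
```

Each backward round raises to an exponent roughly the size of the modulus instead of 3, which is where the slowdown in the backward direction comes from.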
All calculations are done modulo
We start off by building a convenience class that does prime field operations, as well as operations with polynomials over prime fields. The code is here. First some trivial bits:
class PrimeField():
    def __init__(self, modulus):
        # Quick primality test
        assert pow(2, modulus, modulus) == 2
        self.modulus = modulus

    def add(self, x, y):
        return (x+y) % self.modulus

    def sub(self, x, y):
        return (x-y) % self.modulus

    def mul(self, x, y):
        return (x*y) % self.modulus
And the Extended Euclidean Algorithm for computing modular inverses (the equivalent of computing $\frac{1}{x}$ in a prime field):
    # Modular inverse using the extended Euclidean algorithm
    def inv(self, a):
        if a == 0:
            return 0
        lm, hm = 1, 0
        low, high = a % self.modulus, self.modulus
        while low > 1:
            r = high//low
            nm, new = hm-lm*r, high-low*r
            lm, low, hm, high = nm, new, lm, low
        return lm % self.modulus
The above algorithm is relatively expensive; fortunately, for the special case where we need to do many modular inverses, there's a simple mathematical trick, called Montgomery batch inversion, that allows us to compute many inverses using only a single modular inversion plus a few multiplications per value:
Using Montgomery batch inversion to compute modular inverses. Inputs purple, outputs green, multiplication gates black; the red square is the only modular inversion.
The code below implements this algorithm, with some slightly ugly special case logic so that if there are zeroes in the set of what we are inverting, it sets their inverse to 0 and moves along.
    def multi_inv(self, values):
        partials = [1]
        for i in range(len(values)):
            partials.append(self.mul(partials[-1], values[i] or 1))
        inv = self.inv(partials[-1])
        outputs = [0] * len(values)
        for i in range(len(values), 0, -1):
            outputs[i-1] = self.mul(partials[i-1], inv) if values[i-1] else 0
            inv = self.mul(inv, values[i-1] or 1)
        return outputs
This batch inverse algorithm will prove important later on, when we start dealing with dividing sets of evaluations of polynomials.
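As a quick sanity check of the idea, here is a standalone sketch of batch inversion using plain integers. The function name batch_inverse and the toy modulus 31 are assumptions for illustration; it performs the single inversion with Fermat's little theorem rather than the extended Euclidean algorithm used above.

```python
def batch_inverse(values, p):
    """Invert every nonzero entry of `values` mod prime p using one inversion."""
    # prefix[i] is the product of the first i values (zeroes treated as 1)
    prefix = [1]
    for v in values:
        prefix.append(prefix[-1] * (v or 1) % p)
    # The only modular inversion in the whole routine
    inv = pow(prefix[-1], p - 2, p)
    out = [0] * len(values)
    for i in range(len(values) - 1, -1, -1):
        # inv is currently the inverse of values[0] * ... * values[i]
        out[i] = prefix[i] * inv % p if values[i] else 0
        inv = inv * (values[i] or 1) % p
    return out

p = 31
values = [3, 0, 7, 12, 30]
inverses = batch_inverse(values, p)
print(inverses)
```

Zero entries are mapped to zero, mirroring the special-case logic in multi_inv.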
Now we move on to some polynomial operations. We treat a polynomial as an array, where element $i$ is the $i$th degree term (eg. $x^{3} + 2x + 1$ becomes [1, 2, 0, 1]). Here's the operation of evaluating a polynomial at one point:
    # Evaluate a polynomial at a point
    def eval_poly_at(self, p, x):
        y = 0
        power_of_x = 1
        for i, p_coeff in enumerate(p):
            y += power_of_x * p_coeff
            power_of_x = (power_of_x * x) % self.modulus
        return y % self.modulus
Challenge
What is the output of f.eval_poly_at([4, 5, 6], 2) if the modulus is 31?
Answer:
$6 \cdot 2^{2} + 5 \cdot 2 + 4 = 38, 38 \bmod 31 = 7$.
There is also code for adding, subtracting, multiplying and dividing polynomials; this is textbook long addition/subtraction/multiplication/division. The one non-trivial thing is Lagrange interpolation, which takes as input a set of x and y coordinates, and returns the minimal polynomial that passes through all of those points (you can think of it as being the inverse of polynomial evaluation):
    # Build a polynomial that returns 0 at all specified xs
    def zpoly(self, xs):
        root = [1]
        for x in xs:
            root.insert(0, 0)
            for j in range(len(root)-1):
                root[j] -= root[j+1] * x
        return [x % self.modulus for x in root]

    def lagrange_interp(self, xs, ys):
        # Generate master numerator polynomial, eg. (x - x1) * (x - x2) * ... * (x - xn)
        root = self.zpoly(xs)
        # Generate per-value numerator polynomials, eg. for x=x2,
        # (x - x1) * (x - x3) * ... * (x - xn), by dividing the master
        # polynomial back by each x coordinate
        nums = [self.div_polys(root, [-x, 1]) for x in xs]
        # Generate denominators by evaluating numerator polys at each x
        denoms = [self.eval_poly_at(nums[i], xs[i]) for i in range(len(xs))]
        invdenoms = self.multi_inv(denoms)
        # Generate output polynomial, which is the sum of the per-value numerator
        # polynomials rescaled to have the right y values
        b = [0 for y in ys]
        for i in range(len(xs)):
            yslice = self.mul(ys[i], invdenoms[i])
            for j in range(len(ys)):
                if nums[i][j] and ys[i]:
                    b[j] += nums[i][j] * yslice
        return [x % self.modulus for x in b]
See the "M of N" section of this article for a description of the math. Note that we also have special-case methods lagrange_interp_4 and lagrange_interp_2 to speed up the very frequent operations of Lagrange interpolation of degree $< 4$ and degree $< 2$ polynomials.
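The claim that interpolation is the inverse of evaluation can be checked with a small standalone sketch. The helper names below are assumptions for illustration, and this version is deliberately naive (it rebuilds each numerator from scratch rather than dividing a master polynomial as the class method does).

```python
def poly_mul_linear(poly, root, p):
    """Multiply poly (lowest-degree coefficient first) by (x - root) mod p."""
    out = [0] * (len(poly) + 1)
    for i, c in enumerate(poly):
        out[i] = (out[i] - root * c) % p
        out[i + 1] = (out[i + 1] + c) % p
    return out

def eval_poly(poly, x, p):
    """Horner evaluation of a coefficient list at x."""
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % p
    return y

def lagrange_interp(xs, ys, p):
    """Coefficients of the degree < len(xs) polynomial through the given points."""
    result = [0] * len(xs)
    for i, (xi, yi) in enumerate(zip(xs, ys)):
        # Numerator: product of (x - xj) for every j != i
        num = [1]
        for j, xj in enumerate(xs):
            if j != i:
                num = poly_mul_linear(num, xj, p)
        # Rescale so the numerator equals yi at xi (and 0 at the other xs)
        scale = yi * pow(eval_poly(num, xi, p), p - 2, p) % p
        for k, c in enumerate(num):
            result[k] = (result[k] + c * scale) % p
    return result

p = 31
coeffs = [4, 5, 6]   # 4 + 5x + 6x^2, the polynomial from the challenge above
xs = [1, 2, 3]
ys = [eval_poly(coeffs, x, p) for x in xs]
recovered = lagrange_interp(xs, ys, p)
print(ys, recovered)
```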
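If you read the above algorithms carefully, you might notice that Lagrange interpolation and multi-point evaluation (that is, evaluating a degree $< N$ polynomial at $N$ points) both take quadratic time to execute, which becomes prohibitively slow for large inputs. Instead, we use a more efficient algorithm, the Fast Fourier Transform (FFT).
The FFT only takes $O(n \cdot \log(n))$ time, though it is more restricted in scope: the x coordinates must be a complete set of roots of unity of some order $N = 2^{k}$, that is, successive powers $1, p, p^{2}, p^{3}, ...$ of some $p$ where $p^{N} = 1$.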
Challenge Find a 16th root of unity mod 337 that is not an 8th root of unity.
Answer:
59, 146, 30, 297, 278, 191, 307, 40
You could have gotten this list by doing something like [print(x) for x in range(337) if pow(x, 16, 337) == 1 and pow(x, 8, 337) != 1], though there is a smarter way that works for much larger moduluses: first, identify a single value that is not a perfect square mod 337 (a quadratic non-residue), by looking for a value x such that pow(x, 336 // 2, 337) != 1 (these are easy to find; one answer is 5), and then take the (336 / 16)'th power of it.
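The "smarter way" described above can be sketched as follows. The function name find_root_of_unity is an assumption for illustration, and the sketch only handles orders that are powers of two dividing the modulus minus one.

```python
def find_root_of_unity(order, p):
    """Return an element of exact order `order` (a power of two) mod prime p."""
    assert (p - 1) % order == 0
    for x in range(2, p):
        # x is not a perfect square mod p exactly when x^((p-1)/2) != 1
        if pow(x, (p - 1) // 2, p) != 1:
            return pow(x, (p - 1) // order, p)
    raise ValueError("no non-square found")

p = 337
w = find_root_of_unity(16, p)
brute_force = [x for x in range(p) if pow(x, 16, p) == 1 and pow(x, 8, p) != 1]
print(w, brute_force)
```

The value found this way is one of the eight answers listed above; the other seven are its odd powers.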
Here's the algorithm (in a slightly simplified form; see code here for something slightly more optimized):
def fft(vals, modulus, root_of_unity):
    if len(vals) == 1:
        return vals
    L = fft(vals[::2], modulus, pow(root_of_unity, 2, modulus))
    R = fft(vals[1::2], modulus, pow(root_of_unity, 2, modulus))
    o = [0 for i in vals]
    for i, (x, y) in enumerate(zip(L, R)):
        y_times_root = y*pow(root_of_unity, i, modulus)
        o[i] = (x+y_times_root) % modulus
        o[i+len(L)] = (x-y_times_root) % modulus
    return o

def inv_fft(vals, modulus, root_of_unity):
    f = PrimeField(modulus)
    # Inverse FFT
    invlen = f.inv(len(vals))
    return [(x*invlen) % modulus for x in
            fft(vals, modulus, f.inv(root_of_unity))]
You can try running it on a few inputs yourself and check that it gives results that, when you use eval_poly_at on them, give you the answers you expect to get. For example:
>>> fft.fft([3,1,4,1,5,9,2,6], 337, 85, inv=True)
[46, 169, 29, 149, 126, 262, 140, 93]
>>> f = poly_utils.PrimeField(337)
>>> [f.eval_poly_at([46, 169, 29, 149, 126, 262, 140, 93], f.exp(85, i)) for i in range(8)]
[3, 1, 4, 1, 5, 9, 2, 6]
A Fourier transform takes as input [x[0] .... x[n-1]], and its goal is to output x[0] + x[1] + ... + x[n-1] as the first element, x[0] + x[1] * w + ... + x[n-1] * w**(n-1) as the second element, etc etc; a fast Fourier transform accomplishes this by splitting the data in half, doing an FFT on both halves, and then gluing the result back together.
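The definition in the paragraph above can be written directly as a quadratic-time loop, which is handy as a reference to test a fast implementation against. This is a sketch with assumed function names (naive_dft, naive_inverse_dft), reusing the numbers from the mod 337 example above.

```python
def naive_dft(vals, p, w):
    """O(n^2) transform: element k is the sum of vals[j] * w**(j*k) mod p."""
    n = len(vals)
    return [sum(vals[j] * pow(w, j * k, p) for j in range(n)) % p
            for k in range(n)]

def naive_inverse_dft(vals, p, w):
    """Undo naive_dft: transform with the inverse root, then divide by n."""
    n_inv = pow(len(vals), p - 2, p)
    w_inv = pow(w, p - 2, p)
    return [x * n_inv % p for x in naive_dft(vals, p, w_inv)]

p, w = 337, 85   # 85 is an 8th root of unity mod 337
coeffs = [46, 169, 29, 149, 126, 262, 140, 93]
evaluations = naive_dft(coeffs, p, w)
print(evaluations)
```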
A diagram of how information flows through the FFT computation. Notice how the FFT consists of a "gluing" step followed by two copies of the FFT on two halves of the data, and so on recursively until you're down to one element.
I recommend this for more intuition on how or why the FFT works and polynomial math in general, and this thread for some more specifics on DFT vs FFT, though be warned that most literature on Fourier transforms talks about Fourier transforms over real and complex numbers, not prime fields. If you find this too hard and don't want to understand it, just treat it as weird spooky voodoo that just works because you ran the code a few times and verified that it works, and you'll be fine too.
Reminder: now may be a good time to review and re-read Part 2
Now, we'll get into the code for making a low-degree proof. To review, a low-degree proof is a (probabilistic) proof that at least some high percentage (eg. 80%) of a given set of values represent the evaluations of some specific polynomial whose degree is much lower than the number of values given. Intuitively, just think of it as a proof that "some Merkle root that we claim represents a polynomial actually does represent a polynomial, possibly with a few errors". As input, we have:
- A set of values that we claim are the evaluation of a low-degree polynomial
- A root of unity; the x coordinates at which the polynomial is evaluated are successive powers of this root of unity
- A value $N$ such that we are proving the degree of the polynomial is strictly less than $N$
- The modulus
Our approach is a recursive one, with two cases. First, if the degree is low enough, we just provide the entire list of values as a proof; this is the "base case". Verification of the base case is trivial: do an FFT or Lagrange interpolation or whatever else to interpolate the polynomial representing those values, and verify that its degree is $< N$. Otherwise, we use the recursive step from Part 2, which works as follows.
We start off by putting the values into a Merkle tree and using the Merkle root to select a pseudo-random x coordinate (special_x). We then calculate the "column":
# Calculate the set of x coordinates
xs = get_power_cycle(root_of_unity, modulus)
column = []
for i in range(len(xs)//4):
    x_poly = f.lagrange_interp_4(
        [xs[i+len(xs)*j//4] for j in range(4)],
        [values[i+len(values)*j//4] for j in range(4)],
    )
    column.append(f.eval_poly_at(x_poly, special_x))
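This packs a lot into a few lines of code. The broad idea is to re-interpret the polynomial $P(x)$ as a polynomial $Q(x, y)$, where $P(x) = Q(x, x^4)$. If $P$ has degree $< N$, then $P'(y) = Q(special\_x, y)$ has degree $< \frac{N}{4}$. Rather than computing $Q$ in coefficient form, we use the fact that for any given value $y = x^4$ there are four corresponding values of $x$ in the evaluation domain, so we already have four values of $Q(?, y)$. We interpolate the degree $< 4$ polynomial $R(x) = Q(x, y)$ through them and evaluate it at special_x, giving $R(special\_x) = Q(special\_x, y) = P'(y)$; the column is the list of these values, one per value of $x^4$.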
A diagram from part 2; it helps to keep this in mind when understanding what's going on here
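Our proof consists of some number (eg. 40) of random queries from the list of values of $x^4$ (using the Merkle root of the column as a seed), and for each query we provide Merkle branches of five values: the entry in the column and the four corresponding entries in the original values: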
m2 = merkelize(column)

# Pseudo-randomly select y indices to sample
# (m2[1] is the Merkle root of the column)
ys = get_pseudorandom_indices(m2[1], len(column), 40)

# Compute the Merkle branches for the values in the polynomial and the column
branches = []
for y in ys:
    branches.append([mk_branch(m2, y)] +
                    [mk_branch(m, y + (len(xs) // 4) * j) for j in range(4)])
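The verifier's job will be to verify that these five values actually do lie on the same degree $< 4$ polynomial. From there, we recurse and do an FRI on the column, verifying that the column has degree $< \frac{N}{4}$.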
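Here is a toy version of one such step, meant only as a sketch: it assumes the small field mod 337, a domain of 16 points, a hard-coded special_x of 123 standing in for the Merkle-root-derived value, and hypothetical helper names. Writing $P(x) = Q(x, x^4)$ with $Q$ of degree $< 4$ in $x$, it checks that the column really is a degree $< \frac{N}{4}$ polynomial in $y = x^4$.

```python
def eval_poly(poly, x, p):
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % p
    return y

def interp_eval(xs, ys, x, p):
    """Evaluate at x the lowest-degree polynomial through the points (xs, ys)."""
    total = 0
    for i, (xi, yi) in enumerate(zip(xs, ys)):
        num = den = 1
        for j, xj in enumerate(xs):
            if j != i:
                num = num * (x - xj) % p
                den = den * (xi - xj) % p
        total += yi * num * pow(den, p - 2, p)
    return total % p

p, w = 337, 59                      # 59 has order 16 mod 337
coeffs = [3, 1, 4, 1, 5, 9, 2, 6]   # P(x), degree < N = 8
xs = [pow(w, i, p) for i in range(16)]
values = [eval_poly(coeffs, x, p) for x in xs]
special_x = 123                     # assumption: stand-in for the pseudo-random value

quarter = len(xs) // 4
column = [
    interp_eval([xs[i + quarter * j] for j in range(4)],
                [values[i + quarter * j] for j in range(4)],
                special_x, p)
    for i in range(quarter)
]

# Q(special_x, y) = c0 + c1 * y, splitting P's coefficients by exponent // 4
c0 = eval_poly(coeffs[:4], special_x, p)
c1 = eval_poly(coeffs[4:], special_x, p)
expected = [(c0 + c1 * pow(xs[i], 4, p)) % p for i in range(quarter)]
print(column, expected)
```

With $N = 8$ the column has degree $< 2$ in $y$, so two coefficients describe all of its entries.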
As a challenge exercise, you could try creating low-degree proofs of polynomial evaluations that have errors in them, and see how many errors you can get away passing the verifier with (hint, you'll need to modify the prove_low_degree function; with the default prover, even one error will balloon up and cause verification to fail).
Reminder: now may be a good time to review and re-read Part 1
Now, we get to the actual meat that puts all of these pieces together: def mk_mimc_proof(inp, steps, round_constants) (code here), which generates a proof of the execution result of running the MIMC function with the given input for some number of steps. First, some asserts:
assert steps <= 2**32 // extension_factor
assert is_a_power_of_2(steps) and is_a_power_of_2(len(round_constants))
assert len(round_constants) < steps
The extension factor is the extent to which we will be "stretching" the computational trace (the set of "intermediate values" of executing the MIMC function). We need the step count multiplied by the extension factor to be at most $2^{32}$, because the field we are using has no roots of unity of order $2^{k}$ for $k > 32$.
Our first computation will be to generate the computational trace; that is, all of the intermediate values of the computation, from the input going all the way to the output.
# Generate the computational trace
computational_trace = [inp]
for i in range(steps-1):
    computational_trace.append((computational_trace[-1]**3 + round_constants[i % len(round_constants)]) % modulus)
output = computational_trace[-1]
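We then convert the computational trace into a polynomial, "laying down" successive values in the trace on successive powers of a root of unity $g_1$ where $g_1^{steps} = 1$, and we then evaluate the polynomial on a larger set, the successive powers of a root of unity $g_2$ where $g_2^{steps \cdot extension\_factor} = 1$ (note that $g_2^{extension\_factor} = g_1$).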
computational_trace_polynomial = inv_fft(computational_trace, modulus, subroot)
p_evaluations = fft(computational_trace_polynomial, modulus, root_of_unity)
Black: powers of $g_1$. Purple: powers of $g_2$. Orange: 1. You can look at successive roots of unity as being arranged in a circle in this way. We are "laying" the computational trace along powers of $g_1$, and then extending it to compute the values of the same polynomial at the intermediate values (ie. the powers of $g_2$).
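The "lay down, then extend" step can be sketched on a toy scale. Everything below is an assumption for illustration: the prime 257, 8 steps, an extension factor of 4, two made-up round constants, and a naive quadratic-time interpolation in place of the FFT.

```python
P = 257                                  # toy prime; P - 1 = 2**8
STEPS, EXTENSION_FACTOR = 8, 4
PRECISION = STEPS * EXTENSION_FACTOR
G2 = pow(3, (P - 1) // PRECISION, P)     # higher-order root (3 generates the group mod 257)
G1 = pow(G2, EXTENSION_FACTOR, P)        # lower-order root, G1**STEPS == 1

def eval_poly(poly, x):
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % P
    return y

def interpolate_on_powers(vals, root):
    """Naive inverse DFT: coefficients of the polynomial with value vals[i] at root**i."""
    n = len(vals)
    n_inv = pow(n, P - 2, P)
    root_inv = pow(root, P - 2, P)
    return [sum(v * pow(root_inv, j * k, P) for j, v in enumerate(vals)) * n_inv % P
            for k in range(n)]

round_constants = [5, 11]
trace = [3]
for i in range(STEPS - 1):
    trace.append((trace[-1]**3 + round_constants[i % len(round_constants)]) % P)

trace_poly = interpolate_on_powers(trace, G1)
p_evaluations = [eval_poly(trace_poly, pow(G2, i, P)) for i in range(PRECISION)]
print(trace, p_evaluations[::EXTENSION_FACTOR])
```

Every EXTENSION_FACTOR-th entry of the extension lands on a power of $g_1$ and so reproduces the original trace.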
We can convert the round constants of MIMC into a polynomial. Because these round constants loop around very frequently (in our tests, every 64 steps), it turns out that they can be described by a polynomial with only 64 nonzero coefficients, and we can fairly easily compute its expression, and its extension:
skips2 = steps // len(round_constants)
constants_mini_polynomial = fft(round_constants, modulus, f.exp(subroot, skips2), inv=True)
constants_polynomial = [0 if i % skips2 else constants_mini_polynomial[i//skips2] for i in range(steps)]
constants_mini_extension = fft(constants_mini_polynomial, modulus, f.exp(root_of_unity, skips2))
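Suppose there are 8192 steps of execution and 64 round constants. Here is what we are doing: we are doing an FFT to compute the round constants as a function of $g_1^{128}$ (skips2 is $8192 / 64 = 128$). We then add zeroes in between the coefficients to make it a function of $g_1$ itself. Because $g_1^{128}$ loops around every 64 steps, we know this function of $g_1$ will as well, so we only need to compute one period of its extension.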
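The zero-stuffing trick can be checked on a toy scale. This sketch assumes the prime 257, 8 steps and 2 made-up round constants (so skips2 is 4 rather than 128), and uses a naive interpolation helper in place of the inverse FFT.

```python
P = 257
STEPS = 8
G1 = pow(3, (P - 1) // STEPS, P)         # root of unity of order STEPS
round_constants = [5, 11]
skips2 = STEPS // len(round_constants)   # 4 here; 128 in the article's example

def eval_poly(poly, x):
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % P
    return y

def interpolate_on_powers(vals, root):
    """Naive inverse DFT: coefficients of the polynomial with value vals[i] at root**i."""
    n = len(vals)
    n_inv = pow(n, P - 2, P)
    root_inv = pow(root, P - 2, P)
    return [sum(v * pow(root_inv, j * k, P) for j, v in enumerate(vals)) * n_inv % P
            for k in range(n)]

# The constants as a function of G1**skips2 ...
constants_mini_polynomial = interpolate_on_powers(round_constants, pow(G1, skips2, P))
# ... then zero-stuffed so that K(x) = mini(x**skips2) is a function of G1 itself
constants_polynomial = [0 if i % skips2 else constants_mini_polynomial[i // skips2]
                        for i in range(STEPS)]

k_on_trace = [eval_poly(constants_polynomial, pow(G1, i, P)) for i in range(STEPS)]
print(constants_polynomial, k_on_trace)
```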
We now, as in the Fibonacci example in Part 1, calculate $C(P(x))$, except this time it's $C(P(x), P(g_1 \cdot x), K(x))$:
# Create the composed polynomial such that
# C(P(x), P(g1*x), K(x)) = P(g1*x) - P(x)**3 - K(x)
c_of_p_evaluations = [(p_evaluations[(i+extension_factor)%precision] -
                          f.exp(p_evaluations[i], 3) -
                          constants_mini_extension[i % len(constants_mini_extension)])
                      % modulus for i in range(precision)]
print('Computed C(P, K) polynomial')
Note that here we are no longer working with polynomials in coefficient form; we are working with the polynomials in terms of their evaluations at successive powers of the higher-order root of unity.
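c_of_p is intended to be $Q(x) = C(P(x), P(g_1 \cdot x), K(x)) = P(g_1 \cdot x) - P(x)^3 - K(x)$; the goal is that for every $x$ that we are laying the computational trace along (except for the last step, as there's no step "after" the last step), the next value in the trace is equal to the previous value in the trace cubed, plus the round constant. Thus, $Q(x)$ will be equal to zero at all successive powers of the lower-order root of unity $g_1$ (except the last).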
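There is an algebraic theorem that proves that if $Q(x)$ is equal to zero at all of these x coordinates, then it is a multiple of the minimal polynomial that is equal to zero at all of these x coordinates: $Z(x) = (x - x_1) \cdot (x - x_2) \cdot ... \cdot (x - x_n)$. Since checking that $Q(x)$ is zero at every single coordinate would be too expensive for the verifier, we instead (probabilistically) prove that $Q(x)$ is a multiple of $Z(x)$ by providing the quotient $D(x) = \frac{Q(x)}{Z(x)}$ and using FRI to prove that it's an actual polynomial and not a fraction.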
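We chose the particular arrangement of lower and higher order roots of unity (rather than, say, laying the computational trace along the first few powers of the higher order root of unity) because it turns out that computing $Z(x)$ (the polynomial that evaluates to zero at all points along the computational trace except the last), and dividing by $Z(x)$, is trivial there: the expression of $Z$ is a fraction of two terms.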
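To see the transition constraint vanish along the trace, here is a toy sketch under assumed parameters (prime 257, 8 steps, two made-up round constants, naive interpolation). It builds $P$ and $K$ in coefficient form and evaluates $P(g_1 \cdot x) - P(x)^3 - K(x)$ at every power of $g_1$.

```python
P = 257
STEPS = 8
G1 = pow(3, (P - 1) // STEPS, P)   # root of unity of order STEPS
round_constants = [5, 11]

def eval_poly(poly, x):
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % P
    return y

def interpolate_on_powers(vals, root):
    """Naive inverse DFT: coefficients of the polynomial with value vals[i] at root**i."""
    n = len(vals)
    n_inv = pow(n, P - 2, P)
    root_inv = pow(root, P - 2, P)
    return [sum(v * pow(root_inv, j * k, P) for j, v in enumerate(vals)) * n_inv % P
            for k in range(n)]

trace = [3]
for i in range(STEPS - 1):
    trace.append((trace[-1]**3 + round_constants[i % len(round_constants)]) % P)

trace_poly = interpolate_on_powers(trace, G1)
constants_poly = interpolate_on_powers(
    [round_constants[i % len(round_constants)] for i in range(STEPS)], G1)

def q_at(x):
    """Q(x) = P(g1 * x) - P(x)**3 - K(x)"""
    return (eval_poly(trace_poly, G1 * x % P)
            - pow(eval_poly(trace_poly, x), 3, P)
            - eval_poly(constants_poly, x)) % P

q_on_trace = [q_at(pow(G1, i, P)) for i in range(STEPS)]
print(q_on_trace)
```

The last entry is the exception: multiplying the last position by $g_1$ wraps around to the first position, so the constraint is not expected to hold there.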
# Compute D(x) = Q(x) / Z(x)
# Z(x) = (x^steps - 1) / (x - x_atlast_step)
z_num_evaluations = [xs[(i * steps) % precision] - 1 for i in range(precision)]
z_num_inv = f.multi_inv(z_num_evaluations)
z_den_evaluations = [xs[i] - last_step_position for i in range(precision)]
d_evaluations = [cp * zd * zni % modulus for cp, zd, zni in zip(c_of_p_evaluations, z_den_evaluations, z_num_inv)]
print('Computed D polynomial')
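Notice that we compute the numerator and denominator of $Z$ directly in "evaluation form", and then use batch modular inversion to turn dividing by $Z$ into a multiplication. Note that at the powers of the lower-order root of unity except the last (ie. along the original computational trace), $Z(x) = 0$, so this computation involving its inverse will break there; the spot checks and the FRI proof are set up to never sample at those points, so the wrong values computed there never matter.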
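Because $Z(x)$ can be expressed so compactly, we get another benefit: the verifier can compute $Z(x)$ for any specific $x$ very quickly, without needing any precomputation.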
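The compact form of $Z$ from the code comment above, $(x^{steps} - 1) / (x - x_{last})$, can be compared against the explicit product over trace positions. This is a toy sketch assuming the prime 257 and 8 steps; the function names are hypothetical.

```python
P = 257
STEPS = 8
G1 = pow(3, (P - 1) // STEPS, P)            # root of unity of order STEPS
last_step_position = pow(G1, STEPS - 1, P)

def z_product(x):
    """Z(x) as the explicit product over every trace position except the last."""
    out = 1
    for i in range(STEPS - 1):
        out = out * (x - pow(G1, i, P)) % P
    return out

def z_fraction(x):
    """Z(x) = (x**STEPS - 1) / (x - last_step_position), valid off the trace."""
    num = (pow(x, STEPS, P) - 1) % P
    den = (x - last_step_position) % P
    return num * pow(den, P - 2, P) % P

# Points that are not powers of G1, i.e. where the denominator is nonzero
off_trace = [x for x in range(P) if pow(x, STEPS, P) != 1]
print(all(z_product(x) == z_fraction(x) for x in off_trace))
```

The product needs a multiplication per step, while the fraction costs one exponentiation and one inversion regardless of the number of steps.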
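Probabilistically checking $D(x) \cdot Z(x) = Q(x)$ at a few randomly selected points allows us to verify the transition constraints - that each computational step is a valid consequence of the previous step. But we also want to verify the boundary constraints - that the input and the output of the computation are what the prover says they are. For this we use a similar kind of polynomial division trick: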
# Compute interpolant of ((1, input), (x_atlast_step, output))
interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
i_evaluations = [f.eval_poly_at(interpolant, x) for x in xs]

# Z2(x) = (x - 1) * (x - x_atlast_step)
zeropoly2 = f.mul_polys([-1, 1], [-last_step_position, 1])
inv_z2_evaluations = f.multi_inv([f.eval_poly_at(zeropoly2, x) for x in xs])

# B = (P - I) / Z2
b_evaluations = [((p - i) * invq) % modulus for p, i, invq in zip(p_evaluations, i_evaluations, inv_z2_evaluations)]
print('Computed B polynomial')
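The argument is as follows. The prover wants to prove $P(1) = input$ and $P(last\_step) = output$. If we take $I(x)$ as the interpolant - the line that crosses the two points $(1, input)$ and $(last\_step, output)$ - then $P(x) - I(x)$ would be equal to zero at those two points. Thus, it suffices to prove that $P(x) - I(x)$ is a multiple of $Z_2(x) = (x - 1) \cdot (x - last\_step)$, and we do that by providing the quotient $B(x) = \frac{P(x) - I(x)}{Z_2(x)}$.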
Purple: computational trace polynomial (P). Green: interpolant (I) (notice how the interpolant is constructed to equal the input (which should be the first step of the computational trace) at x=1 and the output (which should be the last step of the computational trace) at $x=g^{steps-1}$). Red: $P - I$. Yellow: the minimal polynomial that equals $0$ at $x=1$ and $x=g^{steps-1}$ (that is, $Z_2$). Pink: $\frac{P - I}{Z_2}$.
Challenge Suppose you wanted to also prove that the value in the computational trace after the 703rd computational step is equal to 8018284612598740. How would you modify the above algorithm to do that?
Answer:
Set $I(x)$ to be the interpolant of $(1, input), (g^{703}, 8018284612598740), (last\_step, output)$, and make a proof by providing the quotient $B(x) = \frac{P(x) - I(x)}{(x - 1) \cdot (x - g^{703}) \cdot (x - last\_step)}$
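The two-point boundary argument can be checked in coefficient form on a toy example. This sketch assumes the prime 257, 8 steps and an arbitrary stand-in for $P$; the helper names are hypothetical. It divides $P - I$ by $(x - 1)$ and then by $(x - last\_step)$ and confirms that both remainders are zero.

```python
P = 257
STEPS = 8
G1 = pow(3, (P - 1) // STEPS, P)
last_step_position = pow(G1, STEPS - 1, P)

def eval_poly(poly, x):
    y = 0
    for c in reversed(poly):
        y = (y * x + c) % P
    return y

def divide_by_linear(poly, root):
    """Synthetic division of poly by (x - root); returns (quotient, remainder)."""
    quotient = [0] * (len(poly) - 1)
    carry = 0
    for i in range(len(poly) - 1, 0, -1):
        carry = (poly[i] + carry * root) % P
        quotient[i - 1] = carry
    return quotient, (poly[0] + carry * root) % P

trace_poly = [3, 1, 4, 1, 5, 9, 2, 6]   # hypothetical P(x), degree < STEPS
inp = eval_poly(trace_poly, 1)
output = eval_poly(trace_poly, last_step_position)

# I(x): the line through (1, inp) and (last_step_position, output)
slope = (output - inp) * pow(last_step_position - 1, P - 2, P) % P
interpolant = [(inp - slope) % P, slope]

diff = [(c - (interpolant[i] if i < 2 else 0)) % P for i, c in enumerate(trace_poly)]
q1, r1 = divide_by_linear(diff, 1)
b_poly, r2 = divide_by_linear(q1, last_step_position)
print(r1, r2, b_poly)
```

If inp or output were replaced by a wrong value, at least one of the remainders would be nonzero and no polynomial quotient would exist.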
Now, we commit to the Merkle root of $P$, $D$ and $B$ combined together.
# Compute their Merkle roots
mtree = merkelize([pval.to_bytes(32, 'big') +
                   dval.to_bytes(32, 'big') +
                   bval.to_bytes(32, 'big') for
                   pval, dval, bval in zip(p_evaluations, d_evaluations, b_evaluations)])
print('Computed hash root')
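For readers who want something to run, here is a minimal sketch of Merkle helpers with the same names used in the snippets (merkelize, mk_branch, verify_branch) and the same convention that index 1 of the tree is the root. It is an assumption-laden stand-in, not the linked implementation: it uses blake2s from hashlib as the hash, expects a power-of-two number of byte-string leaves, and omits options such as output_as_int.

```python
from hashlib import blake2s

def blake(data):
    return blake2s(data).digest()

def merkelize(leaves):
    """Build a binary Merkle tree as a flat list; nodes[1] is the root."""
    n = len(leaves)                      # assumed to be a power of two
    nodes = [b""] * n + list(leaves)
    for i in range(n - 1, 0, -1):
        nodes[i] = blake(nodes[2 * i] + nodes[2 * i + 1])
    return nodes

def mk_branch(tree, index):
    """Return the leaf at `index` followed by its sibling hashes up to the root."""
    pos = index + len(tree) // 2
    branch = [tree[pos]]
    while pos > 1:
        branch.append(tree[pos ^ 1])
        pos //= 2
    return branch

def verify_branch(root, index, branch):
    """Recompute the root from a branch; return the leaf if it matches."""
    pos = index + 2 ** (len(branch) - 1)
    node = branch[0]
    for sibling in branch[1:]:
        node = blake(sibling + node) if pos % 2 else blake(node + sibling)
        pos //= 2
    assert node == root
    return branch[0]

leaves = [i.to_bytes(32, 'big') for i in range(8)]
tree = merkelize(leaves)
proof = mk_branch(tree, 5)
print(verify_branch(tree[1], 5, proof) == leaves[5])
```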
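Now, we need to prove that $P$, $D$ and $B$ are all actually polynomials, and of the right maximum degree. Rather than making a separate FRI proof for each, we compute a pseudorandom linear combination of $P$, $D$ and $B$ (using the Merkle root of $P$, $D$ and $B$ as a seed), and do a single FRI proof on that: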
k1 = int.from_bytes(blake(mtree[1] + b'\x01'), 'big')
k2 = int.from_bytes(blake(mtree[1] + b'\x02'), 'big')
k3 = int.from_bytes(blake(mtree[1] + b'\x03'), 'big')
k4 = int.from_bytes(blake(mtree[1] + b'\x04'), 'big')
# Compute the linear combination. We don't even bother calculating it
# in coefficient form; we just compute the evaluations
root_of_unity_to_the_steps = f.exp(root_of_unity, steps)
powers = [1]
for i in range(1, precision):
    powers.append(powers[-1] * root_of_unity_to_the_steps % modulus)

l_evaluations = [(d_evaluations[i] +
                  p_evaluations[i] * k1 + p_evaluations[i] * k2 * powers[i] +
                  b_evaluations[i] * k3 + b_evaluations[i] * powers[i] * k4) % modulus
                  for i in range(precision)]
Unless all three of the polynomials have the right low degree, it's almost impossible that a randomly selected linear combination of them will (you have to get extremely lucky for the terms to cancel), so this is sufficient.
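We want to prove that the degree of $D$ is less than $2 \cdot steps$, and that the degrees of $P$ and $B$ are less than $steps$, so we actually make a random linear combination of $P$, $P \cdot x^{steps}$, $B$, $B \cdot x^{steps}$ and $D$, and check that the degree of this combination is less than $2 \cdot steps$.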
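Why the extra $x^{steps}$ terms matter can be seen in coefficient form. The sketch below uses toy values throughout (prime 257, 4 steps, fixed numbers standing in for the hash-derived k1..k4, and made-up polynomials); the helper names are hypothetical.

```python
p = 257
steps = 4

def degree(poly):
    """Index of the highest nonzero coefficient (-1 for the zero polynomial)."""
    for i in range(len(poly) - 1, -1, -1):
        if poly[i] % p:
            return i
    return -1

def combine(d_poly, p_poly, b_poly, k1, k2, k3, k4):
    """Coefficients of D + k1*P + k2*P*x^steps + k3*B + k4*B*x^steps."""
    size = max(len(d_poly), len(p_poly) + steps, len(b_poly) + steps)
    out = [0] * size
    for i, c in enumerate(d_poly):
        out[i] = (out[i] + c) % p
    for poly, k_low, k_high in ((p_poly, k1, k2), (b_poly, k3, k4)):
        for i, c in enumerate(poly):
            out[i] = (out[i] + k_low * c) % p
            out[i + steps] = (out[i + steps] + k_high * c) % p
    return out

k1, k2, k3, k4 = 17, 101, 33, 250     # stand-ins for the hash-derived values
d_poly = [1, 2, 3, 4, 5, 6, 7]        # degree < 2 * steps
b_poly = [9, 8]                       # degree < steps
honest_p = [3, 1, 4, 1]               # degree < steps
cheating_p = [3, 1, 4, 1, 5]          # degree == steps, which is too high

honest = degree(combine(d_poly, honest_p, b_poly, k1, k2, k3, k4))
cheating = degree(combine(d_poly, cheating_p, b_poly, k1, k2, k3, k4))
unshifted = degree(combine(d_poly, cheating_p, b_poly, k1, 0, k3, 0))
print(honest, cheating, unshifted)
```

Without the shifted terms, the too-high-degree $P$ hides below the looser bound on $D$; with them, its top coefficient is pushed to degree $2 \cdot steps$ and the combination fails the check.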
Now, we do some spot checks of all of the polynomials. We generate some random indices, and provide the Merkle branches of the polynomial evaluated at those indices:
# Do some spot checks of the Merkle tree at pseudo-random coordinates, excluding
# multiples of `extension_factor`
branches = []
samples = spot_check_security_factor
positions = get_pseudorandom_indices(l_mtree[1], precision, samples,
                                     exclude_multiples_of=extension_factor)
for pos in positions:
    branches.append(mk_branch(mtree, pos))
    branches.append(mk_branch(mtree, (pos + skips) % precision))
    branches.append(mk_branch(l_mtree, pos))
print('Computed %d spot checks' % samples)
The get_pseudorandom_indices function returns some random indices in the range [0...precision-1], and the exclude_multiples_of parameter tells it to not give values that are multiples of the given parameter (here, extension_factor). This ensures that we do not sample along the original computational trace, where we are likely to get wrong answers.
The proof (~250-500 kilobytes altogether) consists of a set of Merkle roots, the spot-checked branches, and a low-degree proof of the random linear combination:
o = [mtree[1],
     l_mtree[1],
     branches,
     prove_low_degree(l_evaluations, root_of_unity, steps * 2, modulus, exclude_multiples_of=extension_factor)]
The largest parts of the proof in practice are the Merkle branches, and the FRI proof, which consists of even more branches. And here's the "meat" of the verifier:
for i, pos in enumerate(positions):
    x = f.exp(G2, pos)
    x_to_the_steps = f.exp(x, steps)
    mbranch1 = verify_branch(m_root, pos, branches[i*3])
    mbranch2 = verify_branch(m_root, (pos+skips)%precision, branches[i*3+1])
    l_of_x = verify_branch(l_root, pos, branches[i*3 + 2], output_as_int=True)

    p_of_x = int.from_bytes(mbranch1[:32], 'big')
    p_of_g1x = int.from_bytes(mbranch2[:32], 'big')
    d_of_x = int.from_bytes(mbranch1[32:64], 'big')
    b_of_x = int.from_bytes(mbranch1[64:], 'big')

    zvalue = f.div(f.exp(x, steps) - 1,
                   x - last_step_position)
    k_of_x = f.eval_poly_at(constants_mini_polynomial, f.exp(x, skips2))

    # Check transition constraints Q(x) = Z(x) * D(x)
    assert (p_of_g1x - p_of_x ** 3 - k_of_x - zvalue * d_of_x) % modulus == 0

    # Check boundary constraints B(x) * Z2(x) + I(x) = P(x)
    interpolant = f.lagrange_interp_2([1, last_step_position], [inp, output])
    zeropoly2 = f.mul_polys([-1, 1], [-last_step_position, 1])
    assert (p_of_x - b_of_x * f.eval_poly_at(zeropoly2, x) -
            f.eval_poly_at(interpolant, x)) % modulus == 0

    # Check correctness of the linear combination
    assert (l_of_x - d_of_x -
            k1 * p_of_x - k2 * p_of_x * x_to_the_steps -
            k3 * b_of_x - k4 * b_of_x * x_to_the_steps) % modulus == 0
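At every one of the positions that the prover provides a Merkle proof for, the verifier checks the Merkle proof, and checks that $C(P(x), P(g_1 \cdot x), K(x)) = Z(x) \cdot D(x)$ and $B(x) \cdot Z_2(x) + I(x) = P(x)$. The verifier also checks that the linear combination is correct, and calls verify_low_degree_proof(l_root, root_of_unity, fri_proof, steps * 2, modulus, exclude_multiples_of=extension_factor) to verify the FRI proof. And we're done!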
Well, not really; soundness analysis to prove how many spot-checks for the cross-polynomial checking and for the FRI are necessary is really tricky. But that's all there is to the code, at least if you don't care about making even crazier optimizations. When I run the code above, we get a STARK proving "overhead" of about 300-400x (eg. a MIMC computation that takes 0.2 seconds to calculate takes 60 seconds to prove), suggesting that with a 4-core machine computing the STARK of the MIMC computation in the forward direction could actually be faster than computing MIMC in the backward direction. That said, these are both relatively inefficient implementations in python, and the proving to running time ratio for properly optimized implementations may be different. Also, it's worth pointing out that the STARK proving overhead for MIMC is remarkably low, because MIMC is almost perfectly "arithmetizable" - its mathematical form is very simple. For "average" computations, which contain less arithmetically clean operations (eg. checking if a number is greater or less than another number), the overhead is likely much higher, possibly around 10000-50000x.