Optimization: Sumcheck::prove_cubic_with_additive_term #4

Open
sragss opened this issue Jan 17, 2024 · 2 comments


sragss commented Jan 17, 2024

Sumcheck::prove_cubic_with_additive_term seems suboptimal; it currently takes ~6% of Spartan end-to-end prover time.

https://github.com/a16z/Spartan2/blob/uniform_r1cs_shape/src/spartan/sumcheck.rs#L251

Some ideas on optimization follow.

0 / 1 Checking

comb_func_outer gets passed into Sumcheck::prove_cubic_with_additive_term from Snark::prove. This combination function is f(a, b, c, d) = a * (b * c - d). There are clear optimizations to be had when any of the terms are 0 or 1. Specifically, if a is zero we should short-circuit. The remaining cases are less impactful but can theoretically save up to 66% of the field multiplications.

compute_eval_points_cubic

This function is parallelized over the length of the 4 MLEs passed in, but is missing some optimizations.
This is the binding function (bound_poly_var_top pairs index i with i + n, where n = mle_evals.len() / 2):

for i in 0..n:
    low  = mle_evals[i]
    high = mle_evals[i + n]
    f(r) = low + r * (high - low)

We compute f(r) for r = 0, 2, 3 (the eval at 1 can be derived from the round claim, since f(0) + f(1) must equal the claimed sum).

To expand this a bit we have:

f(0) = low + 0 * (high - low) = low
f(2) = low + 2 * (high - low) = high + high - low
f(3) = low + 3 * (high - low) = f(2) + high - low

We can precompute m = high - low.

m =  high - low
f(0) = low
f(2) = high + m
f(3) = f(2) + m

This trades the multiplications by 2 and 3 for a few field additions.

Next, notice that if high / low have a high probability of being 0 or 1, we get some interesting properties:

  • high == low => m=0 => f(2) = f(3) = high
  • m=0 => comb_func(f_a(2), f_b(2), f_c(2), f_d(2)) == comb_func(f_a(3), f_b(3), f_c(3), f_d(3))
    There are some other combos that are likely less relevant and rarer. May be worth exploring.
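The precomputed-m evaluation, including the m == 0 collapse, can be sketched as follows. This uses i64 in place of G::Scalar, and `eval_points` is an illustrative name rather than the in-tree API:

```rust
// Per-index step of compute_eval_points_cubic with the slope m = high - low
// precomputed once. i64 stands in for a field scalar.
fn eval_points(low: i64, high: i64) -> (i64, i64, i64) {
    let m = high - low;
    let e0 = low;      // f(0) = low
    let e2 = high + m; // f(2) = 2*high - low: one add instead of a mult by 2
    let e3 = e2 + m;   // f(3) = f(2) + m
    // When m == 0 (e.g. high == low, common for 0/1-dense polys),
    // e2 == e3 == high, so the caller can evaluate comb_func once at
    // r = 2 and reuse the result for r = 3.
    (e0, e2, e3)
}

fn main() {
    // Cross-check against the direct formula f(r) = low + r * (high - low).
    let (low, high) = (5, 9);
    let f = |r: i64| low + r * (high - low);
    assert_eq!(eval_points(low, high), (f(0), f(2), f(3)));
    // Degenerate case: high == low collapses f(2) and f(3).
    assert_eq!(eval_points(7, 7), (7, 7, 7));
}
```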

Poly Binding

At the end of each round of Sumcheck::prove_cubic_with_additive_term the 4 polynomials are bound (bound_poly_var_top). These can all be executed in parallel rather than serially. The bound_poly_var_top function is itself parallelized, but it is worth determining experimentally whether a different parallelization shape is more efficient from a memory-contention perspective (I suspect it will be).
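A minimal sketch of binding the four polynomials concurrently, using std::thread::scope and i64 stand-ins for the field-scalar polynomials (in-tree this would more likely be rayon over the actual DensePolynomials; `bind_top` and the data here are illustrative):

```rust
use std::thread;

// bind_top mirrors bound_poly_var_top on a plain Vec<i64>:
// pair index i with i + n and fold with f(r) = low + r * (high - low).
fn bind_top(evals: &mut Vec<i64>, r: i64) {
    let n = evals.len() / 2;
    for i in 0..n {
        let (low, high) = (evals[i], evals[i + n]);
        evals[i] = low + r * (high - low);
    }
    evals.truncate(n);
}

fn main() {
    // Four MLEs (A, B, C, D stand-ins) bound to the same point r.
    let mut polys = [
        vec![1, 2, 3, 4],
        vec![0, 1, 0, 1],
        vec![2, 2, 2, 2],
        vec![5, 0, 5, 0],
    ];
    let r = 3;
    // One scoped thread per polynomial: the four binds run in
    // parallel instead of one after another.
    thread::scope(|s| {
        for p in polys.iter_mut() {
            s.spawn(move || bind_top(p, r));
        }
    });
    assert_eq!(polys[0], vec![7, 8]); // 1 + 3*(3-1), 2 + 3*(4-2)
}
```

Whether a thread (or rayon task) per polynomial beats parallelizing inside each bind is exactly the memory-contention question above, and needs measurement.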

Inline Poly Binding

The sumcheck loop has two sections: first, evaluate the joint polynomial p(b, ...) = f_a(b, ..) * [f_b(b, ..) * f_c(b, ..) - f_d(b, ..)] over the boolean hypercube; then, bind each of the multilinear polynomials f_a / f_b / f_c / f_d to a point r derived from the prior evaluations. I usually call these the eval loop and the binding loop. Interestingly, they perform much of the same work. Above (in compute_eval_points_cubic) I describe the eval loop algorithm. The binding loop does the same but for f(r) instead of f({0,2,3}). This means it may be plausible to keep m around to compute low' = low + r * m. I believe this saves exactly one field subtraction (the high - low) per step at the cost of significant RAM, but it is plausible there are some memory-performance improvements when tested experimentally.
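A sketch of the fused shape for a single polynomial, again with i64 stand-ins and illustrative names (the real eval loop additionally combines four polynomials through comb_func; this only demonstrates the m caching):

```rust
// Eval pass: accumulate the round sums f(0), f(2), f(3) and cache
// m = high - low for every index so the binding pass can reuse it.
fn eval_and_cache(evals: &[i64]) -> (Vec<i64>, i64, i64, i64) {
    let n = evals.len() / 2;
    let mut ms = Vec::with_capacity(n);
    let (mut e0, mut e2, mut e3) = (0i64, 0i64, 0i64);
    for i in 0..n {
        let (low, high) = (evals[i], evals[i + n]);
        let m = high - low;
        let f2 = high + m;
        e0 += low;    // f(0) contribution
        e2 += f2;     // f(2) contribution
        e3 += f2 + m; // f(3) contribution
        ms.push(m);   // the RAM cost: one cached slope per index
    }
    (ms, e0, e2, e3)
}

// Binding pass: low' = low + r * m, skipping the high - low subtraction.
fn bind_with_cache(evals: &mut Vec<i64>, ms: &[i64], r: i64) {
    let n = evals.len() / 2;
    for i in 0..n {
        evals[i] += r * ms[i];
    }
    evals.truncate(n);
}

fn main() {
    let mut evals = vec![1, 2, 3, 4];
    let (ms, e0, e2, e3) = eval_and_cache(&evals);
    assert_eq!((e0, e2, e3), (3, 11, 15));
    bind_with_cache(&mut evals, &ms, 3);
    assert_eq!(evals, vec![7, 8]); // same result as the unfused binding
}
```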

sragss commented Jan 20, 2024

Here's a poly A, B, C density chart for Sha256:

Poly Az ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1339241
0x0000000000000000000000000000000000000000000000000000000000000001: 390838
0x0000000000000000000000000000000000000000000000000000000000000002: 266101
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 99671
Total with only a single appearance: 1301


Poly Bz ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1480606
0x0000000000000000000000000000000000000000000000000000000000000001: 616546
Total with only a single appearance: 0


Poly uCz_E ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1863816
0x0000000000000000000000000000000000000000000000000000000000000002: 132442
0x0000000000000000000000000000000000000000000000000000000000000001: 49820
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 49773
Total with only a single appearance: 1301
Time elapsed is: 7.962553084s

sragss commented Jan 20, 2024

An optimized 0/1 comb function makes it ~33% faster:

let comb_func_outer = |poly_A_comp: &G::Scalar,
                       poly_B_comp: &G::Scalar,
                       poly_C_comp: &G::Scalar,
                       poly_D_comp: &G::Scalar|
 -> G::Scalar {
  // Goal: compute *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) fast.
  // poly_A we know to be uniformly random
  // poly_B: A matrix, poly_C: B matrix, poly_D: C matrix
  if poly_B_comp.eq(&G::Scalar::ZERO) || poly_C_comp.eq(&G::Scalar::ZERO) {
    *poly_A_comp * poly_D_comp.neg()
  } else {
    let inner = if poly_B_comp.eq(&G::Scalar::ONE) {
      *poly_C_comp - *poly_D_comp
    } else if poly_C_comp.eq(&G::Scalar::ONE) {
      *poly_B_comp - *poly_D_comp
    } else {
      *poly_B_comp * *poly_C_comp - *poly_D_comp
    };
    *poly_A_comp * inner
  }
};
