diff --git a/src/circle.py b/src/circle.py index 18a9295..839d6f9 100644 --- a/src/circle.py +++ b/src/circle.py @@ -130,23 +130,50 @@ def evaluate_at_point(evals, domain, point, debug=False): return sum([lagrange_den[i] * evals[i] for i in range(len(evals))]) * lagrange_num +def eval_at_p_recursive(evals, twiddle, debug=False): + if len(evals) == 1: + return evals[0] + else: + f0 = eval_at_p_recursive(evals[:len(evals)//2], pi(twiddle), debug) + f1 = eval_at_p_recursive(evals[len(evals)//2:], pi(twiddle), debug) + return f0 + f1 * twiddle + +def eval_at_point_raw(evals, domain, point, debug=False): + + x, y = point + poly = CFFT.vec_2_poly(evals, domain) + coeffs = CFFT.ifft(poly) + + left, right = coeffs[:len(coeffs)//2], coeffs[len(coeffs)//2:] + left_eval = eval_at_p_recursive(left, x) + right_eval = eval_at_p_recursive(right, x) + + return left_eval + right_eval * y + def deep_quotient_vanishing_part(x, zeta, alpha_pow_width, debug=False): - v_p = lambda p, at: (1 - (p - at)[0], -(p - at)[1]) + v_p = lambda p, at: (1 - group_mul(p, group_inv(at))[0], -group_mul(p, group_inv(at))[1]) re_v_zeta, im_v_zeta = v_p(x, zeta) # if debug: print('re_v_zeta:', re_v_zeta, 'im_v_zeta:', im_v_zeta) - return (re_v_zeta - alpha_pow_width * im_v_zeta, re_v_zeta ** 2 + im_v_zeta ** 2) + # return (re_v_zeta - alpha_pow_width * im_v_zeta, re_v_zeta ** 2 + im_v_zeta ** 2) + return (re_v_zeta - im_v_zeta * alpha_pow_width, re_v_zeta ** 2 + im_v_zeta ** 2) def deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug=False): vp_nums, vp_demons = zip(*[(deep_quotient_vanishing_part(x, zeta, alpha, debug)) for x in domain]) vp_denom_invs = batch_multiplicative_inverse(vp_demons) if debug: print('vp_nums:', vp_nums, 'vp_denom_invs:', vp_denom_invs, 'p_at_zeta:', p_at_zeta, 'evals:', evals) - return [vp_nums[i] * vp_denom_invs[i] * (-p_at_zeta + evals[i]) for i in range(len(evals))] + return [vp_denom_invs[i] * vp_nums[i] * group_mul(group_inv(p_at_zeta), evals[i]) for i in range(len(evals))] def deep_quotient_reduce_row(alpha, x, zeta, ps_at_x, ps_at_zeta, debug=False): vp_num, vp_denom = deep_quotient_vanishing_part(x, zeta, alpha) if debug: print('vp_num:', vp_num, 'vp_denom:', vp_denom, 'ps_at_x:', ps_at_x, 'ps_at_zeta:', ps_at_zeta) - return vp_num * (-ps_at_zeta + ps_at_x) / vp_denom + return vp_num * group_mul(group_inv(ps_at_zeta), ps_at_x) / vp_denom + +def deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta, debug=False): + res = [] + for ps_at_x, x in zip(evals, domain): + res.append(deep_quotient_reduce_row(alpha, x, zeta, ps_at_x, p_at_zeta, debug)) + return res def extract_lambda(lde, log_blowup, debug=False): if debug: @@ -208,13 +235,6 @@ def combine(cosets): res += [t] return res -# test twin_cosets -tcs = twin_cosets(2, 4) -for tc in tcs: - for t in tc: - assert t in standard_position_cosets[log_2(8)] -assert combine(tcs) == standard_position_cosets[log_2(8)], f'combine error, {combine(tcs)}, {standard_position_cosets[log_2(8)]}' - class CFFT: @classmethod def _ifft_first_step(cls, f): @@ -344,8 +364,8 @@ def vec_2_poly(cls, vec, domain): return f @classmethod - def poly_2_vec(cls, poly): - return [poly[t] for t in poly] + def poly_2_vec(cls, poly, domain): + return [poly[t] for t in domain] @classmethod def extrapolate(cls, evals, domain, blowup_factor): @@ -354,8 +374,7 @@ def extrapolate(cls, evals, domain, blowup_factor): cosets = twin_cosets(blowup_factor, len(evals)) res = [] for coset in cosets: - res += [cls.fft(coeffs, coset)] - res = [cls.poly_2_vec(x) for x in res] + res += [cls.poly_2_vec(cls.fft(coeffs, coset), coset)] return combine(res) class FRI: @@ -625,11 +644,11 @@ def open(cls, evals, evals_commit, zeta, log_blowup, transcript, num_queries, de # evaluate the polynomial at the point zeta domain = cls.natural_domain_for_degree(len(evals)) - p_at_zeta = evaluate_at_point(evals, domain, zeta, debug) + p_at_zeta = eval_at_point_raw(evals, domain, zeta, debug) if debug: print('p_at_zeta:', p_at_zeta) # deep quotient - reduced_opening = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta, debug) + reduced_opening = deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta, debug) if debug: print('reduced_opening:', reduced_opening) # extract lambda first_layer, lambda_ = extract_lambda(reduced_opening, log_blowup, debug) @@ -753,4 +772,4 @@ def open_input(index, input_proof): transcript = MerlinTranscript(b'circle pcs') transcript.append_message(b'commitment', bytes(str(commitment.root), 'ascii')) - CirclePCS.verify(commitment.root, domain, log_blowup, point, evaluate_at_point(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True) + CirclePCS.verify(commitment.root, domain, log_blowup, point, eval_at_point_raw(evals, CirclePCS.natural_domain_for_degree(len(evals)), point), proof, transcript, True) diff --git a/tests/test_circle.py b/tests/test_circle.py new file mode 100644 index 0000000..199951d --- /dev/null +++ b/tests/test_circle.py @@ -0,0 +1,92 @@ +import unittest +from sage.all import * +import sys + +sys.path.append('src') +sys.path.append('../src') + +from circle import CFFT, F31, C31, FRI, eval_at_point_raw, twin_cosets, combine, standard_position_cosets, log_2, deep_quotient_reduce, deep_quotient_reduce_raw, g_30 +from merlin.merlin_transcript import MerlinTranscript + +def fold(lde, domain, chunk_size, fold_y=False): + if fold_y: + assert len(domain) == len(lde), f'len(domain) != len(lde), {len(domain)}, {len(lde)}' + else: + assert len(domain) == len(lde) * 2, f'len(domain) != len(lde) * 2, {len(domain)}, {len(lde) * 2}' + res = [] + for j in range(len(lde) // chunk_size): + for i in range(chunk_size // 2): + left = lde[j * chunk_size + i] + right = lde[(j + 1) * chunk_size - i - 1] + t = domain[i][1 if fold_y else 0] + # print('t:', t) + f0 = (left + right) / F31(2) + f1 = (left - right) / (F31(2) * t) + assert f0 + f1 * t == left + assert f0 - f1 * t == right + res += [f0 + f1 * 3] + return res + +class TestCircle(unittest.TestCase): + def test_twin_cosets(self): + # test twin_cosets + tcs = twin_cosets(2, 4) + for tc in tcs: + for t in tc: + assert t in standard_position_cosets[log_2(8)] + assert combine(tcs) == standard_position_cosets[log_2(8)], f'combine error, {combine(tcs)}, {standard_position_cosets[log_2(8)]}' + + def test_extrapolate(self): + evals = [1, 2, 3, 4] + domain = standard_position_cosets[log_2(len(evals))] + blowup_factor = 2 + lde = CFFT.extrapolate(evals, domain, blowup_factor) + + assert len(lde) == len(evals) * blowup_factor, f'len(lde) != len(evals) * blowup_factor, {len(lde)}, {len(evals) * blowup_factor}' + for i, p in enumerate(standard_position_cosets[log_2(len(evals) * blowup_factor)]): + assert eval_at_point_raw(evals, domain, p) == lde[i], f'evaluate_at_point error, {eval_at_point_raw(evals, domain, p)}, {lde[i]}' + + def test_fold(self): + evals = [1, 2, 3, 4] + domain = standard_position_cosets[log_2(len(evals))] + blowup_factor = 2 + lde = CFFT.extrapolate(evals, domain, blowup_factor) + + domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)] + # print('domain_lde:', domain_lde) + folded = fold(lde, domain_lde, len(lde), fold_y=True) + folded_folded = fold(folded, domain_lde, len(lde) // 2, fold_y=False) + assert folded_folded[0] == folded_folded[1], f'folded_folded[0] != folded_folded[1], {folded_folded[0]}, {folded_folded[1]}' + + def test_fri_prove(self): + evals = [1, 2, 3, 4] + domain = standard_position_cosets[log_2(len(evals))] + blowup_factor = 2 + lde = CFFT.extrapolate(evals, domain, blowup_factor) + + domain_lde = standard_position_cosets[log_2(len(evals) * blowup_factor)] + # print('domain_lde:', domain_lde) + folded = fold(lde, domain_lde, len(lde), fold_y=True) + + f0 = [lde[0]] + lde[3:5] + [lde[7]] + f1 = lde[1:3] + lde[5:7] + + tcs = twin_cosets(2, 4) + assert CFFT.ifft(CFFT.vec_2_poly(f0, tcs[0])) == CFFT.ifft(CFFT.vec_2_poly(f1, tcs[1])) + + transcript = MerlinTranscript(b'TEST') + _fri_proof = FRI.prove(folded, blowup_factor, [x[0] for x in domain_lde[:len(folded)]], transcript, lambda x: None, 1) + + def test_deep_quotient_reduce(self): + evals = [C31(1), C31(2), C31(3), C31(4)] + domain = standard_position_cosets[log_2(len(evals))] + alpha = 3 + zeta = g_30 ** 6 + p_at_zeta = eval_at_point_raw(evals, domain, zeta) + + reduced = deep_quotient_reduce(evals, domain, alpha, zeta, p_at_zeta) + expected = deep_quotient_reduce_raw(evals, domain, alpha, zeta, p_at_zeta) + assert reduced == expected, f'deep_quotient_reduce error, {reduced}, {expected}' + +if __name__ == '__main__': + unittest.main()