From 66a51d68794a041586bdfc3f5a1295f622f99e8c Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Mon, 16 Sep 2024 01:29:42 +0300 Subject: [PATCH] Added rational functions for formal polynomial evaluator. --- crates/prover/src/constraint_framework/mod.rs | 1 + .../prover/src/constraint_framework/poly.rs | 100 +++++++++-- .../src/constraint_framework/poly_eval.rs | 75 ++++++-- .../src/constraint_framework/rational.rs | 163 ++++++++++++++++++ crates/prover/src/core/lookups/utils.rs | 77 ++++++++- 5 files changed, 384 insertions(+), 32 deletions(-) create mode 100644 crates/prover/src/constraint_framework/rational.rs diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 0c23dc0b0..fe1a8e8fe 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -8,6 +8,7 @@ pub mod logup; mod point; pub mod poly; pub mod poly_eval; +pub mod rational; mod simd_domain; use std::array; diff --git a/crates/prover/src/constraint_framework/poly.rs b/crates/prover/src/constraint_framework/poly.rs index f36a8d0b6..3e6b88c00 100644 --- a/crates/prover/src/constraint_framework/poly.rs +++ b/crates/prover/src/constraint_framework/poly.rs @@ -1,13 +1,12 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use itertools::Itertools; use num_traits::{One, Zero}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::FieldExpOps; /// A monic monomial consists of a list of variables and their exponents. #[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Ord, Clone)] @@ -22,9 +21,13 @@ impl Monomial { } fn default() -> Monomial { - Monomial { - vars: [(0, 0)].into(), - } + Monomial { vars: [].into() } + } +} + +impl One for Monomial { + fn one() -> Self { + Monomial { vars: [].into() } } } @@ -34,6 +37,21 @@ pub struct Polynomial> { monomials: BTreeMap, } +impl + One> Polynomial { + /// Returns the polynomial x_ind. + pub fn from_var_index(ind: usize) -> Polynomial { + Self { + monomials: [( + Monomial { + vars: [(ind, 1)].into(), + }, + F::one(), + )] + .into(), + } + } +} + impl> From for Polynomial { fn from(monomial: Monomial) -> Self { Self { @@ -73,6 +91,16 @@ impl> Add for Polynomial { } } +impl Sub for Polynomial +where + F: Zero + Add + Clone + From + Neg, +{ + type Output = Self; + fn sub(self, rhs: Self) -> Self { + self + (-rhs) + } +} + #[allow(clippy::suspicious_arithmetic_impl)] impl Mul for Monomial { type Output = Self; @@ -199,14 +227,29 @@ where } } -impl FieldExpOps for Polynomial -where - F: Zero + One + Clone + Add + Mul + AddAssign + From, -{ - fn inverse(&self) -> Self { - assert!(!self.is_zero(), "0 has no inverse"); - let mut res = Self::from(BaseField::one()) / self.clone(); - res +impl From> for Polynomial { + fn from(poly: Polynomial) -> Self { + Self { + monomials: poly + .monomials + .into_iter() + .map(|(m, c)| (m, c.into())) + .collect(), + } + } +} + +impl Add> for Polynomial { + type Output = Self; + fn add(self, rhs: Polynomial) -> Self { + self + Polynomial::::from(rhs) + } +} + +impl Mul> for Polynomial { + type Output = Self; + fn mul(self, rhs: Polynomial) -> Self { + self * Polynomial::::from(rhs) } } @@ -263,10 +306,13 @@ impl Display for Monomial { impl Display for Polynomial where - F: Display + Zero + Add + Clone + From, + F: Display + Zero + Add + Clone + From + One + PartialEq, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut is_first = true; + if self.is_zero() { + return write!(f, "0"); + } for (monomial, coef) in self.monomials.iter() { if !coef.is_zero() { if !is_first { @@ -274,9 +320,13 @@ where } else { is_first = false; } - write!(f, "{}", coef)?; + if !coef.is_one() { + write!(f, "{}", coef)?; + } if !monomial.vars.is_empty() { write!(f, "{}", monomial)?; + } else if coef.is_one() { + write!(f, "1")?; } } } @@ -362,6 +412,17 @@ mod tests { vars: [(0, 1), (1, 3), (2, 2)].into() } ); + + let monomial1 = Monomial { + vars: [(0, 1), (1, 2)].into(), + }; + let monomial2 = Monomial { vars: [].into() }; + assert_eq!( + monomial1.clone() * monomial2.clone(), + Monomial { + vars: [(0, 1), (1, 2)].into() + } + ); } #[test] @@ -376,7 +437,12 @@ mod tests { vars: [(4, 1), (5, 2)].into(), }; let poly1 = Polynomial:: { - monomials: [(monomial1.clone(), M31(1)), (monomial2.clone(), M31(2))].into(), + monomials: [ + (Monomial::default(), M31(12)), + (monomial1.clone(), M31(1)), + (monomial2.clone(), M31(2)), + ] + .into(), }; let poly2 = Polynomial:: { monomials: [(monomial2.clone(), M31(5)), (monomial3.clone(), -M31(8))].into(), @@ -385,6 +451,8 @@ mod tests { assert_eq!( (poly1.clone() * poly2.clone()).monomials, [ + (monomial2.clone(), M31(60)), + (monomial3.clone(), -M31(96)), (monomial1.clone() * monomial2.clone(), M31(5)), (monomial1.clone() * monomial3.clone(), -M31(8)), (monomial2.clone() * monomial2.clone(), M31(10)), diff --git a/crates/prover/src/constraint_framework/poly_eval.rs b/crates/prover/src/constraint_framework/poly_eval.rs index f8e175f23..19bdcc812 100644 --- a/crates/prover/src/constraint_framework/poly_eval.rs +++ b/crates/prover/src/constraint_framework/poly_eval.rs @@ -1,28 +1,83 @@ -use super::poly::Polynomial; +use num_traits::One; + +use super::rational::Rational; use super::EvalAtRow; -use crate::core::fields::m31::BaseField; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -pub struct PolyEvaluator<'a> {} +pub struct PolyEvaluator { + pub cur_var_index: usize, + pub constraints: Vec>, +} -impl<'a> EvalAtRow for PolyEvaluator<'a> { - type F = Polynomial; - type EF = Polynomial; +impl EvalAtRow for PolyEvaluator { + type F = Rational; + type EF = Rational; fn next_interaction_mask( &mut self, - interaction: usize, + _interaction: usize, offsets: [isize; N], ) -> [Self::F; N] { - unimplemented!() + // TODO(alont) support non-zero offsets. + assert_eq!(offsets, [0; N]); + self.cur_var_index += 1; + std::array::from_fn(|_| Self::F::from_var_index(self.cur_var_index - 1)) } fn add_constraint(&mut self, constraint: G) where Self::EF: std::ops::Mul, { - unimplemented!() + self.constraints.push(Self::EF::one() * constraint) } fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { - unimplemented!() + let [one, i, u, iu] = [ + SecureField::one(), + SecureField::from_m31_array([M31(0), M31(1), M31(0), M31(0)]), + SecureField::from_m31_array([M31(0), M31(0), M31(1), M31(0)]), + SecureField::from_m31_array([M31(0), M31(0), M31(0), M31(1)]), + ]; + values[0].clone() * one + + values[1].clone() * i + + values[2].clone() * u + + values[3].clone() * iu + } +} + +#[cfg(test)] +mod tests { + + use super::PolyEvaluator; + use crate::constraint_framework::{EvalAtRow, FrameworkEval}; + use crate::core::fields::FieldExpOps; + struct TestStruct {} + impl FrameworkEval for TestStruct { + fn log_size(&self) -> u32 { + 1 << 16 + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + 1 << 17 + } + + fn evaluate(&self, mut eval: E) -> E { + let x0 = eval.next_trace_mask(); + let x1 = eval.next_trace_mask(); + let x2 = eval.next_trace_mask(); + + eval.add_constraint(x0.clone() * x1.clone() * x2 * (x0 + x1).inverse()); + eval + } + } + + #[test] + fn test_poly_eval() { + let test_struct = TestStruct {}; + let eval = test_struct.evaluate(PolyEvaluator { + cur_var_index: 0, + constraints: vec![], + }); + + assert_eq!(eval.constraints[0].to_string(), "(x₀x₁x₂) / (x₀ + x₁)"); } } diff --git a/crates/prover/src/constraint_framework/rational.rs b/crates/prover/src/constraint_framework/rational.rs new file mode 100644 index 000000000..7d40d64d9 --- /dev/null +++ b/crates/prover/src/constraint_framework/rational.rs @@ -0,0 +1,163 @@ +use std::fmt::{Display, Formatter}; +use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub}; + +use num_traits::{One, Zero}; + +use crate::constraint_framework::poly::Polynomial; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::utils::Fraction; + +pub type Rational = Fraction, Polynomial>; + +impl Rational +where + F: From + One + AddAssign + Clone + Add, +{ + /// Returns the rational function x_ind / 1. + pub fn from_var_index(ind: usize) -> Rational { + Self::new(Polynomial::from_var_index(ind), Polynomial::one()) + } +} + +impl From> for Rational +where + F: From + One + AddAssign + Clone + Add, +{ + fn from(polynomial: Polynomial) -> Self { + Fraction::new(polynomial, Polynomial::one()) + } +} + +impl From for Rational +where + F: From + One + AddAssign + Clone + Add, +{ + fn from(value: F) -> Self { + Fraction::new(value.into(), Polynomial::one()) + } +} + +impl AddAssign for Rational +where + F: From + One + Clone + Mul + AddAssign + Add + Zero, +{ + fn add_assign(&mut self, other: BaseField) { + self.numerator = self.numerator.clone() + (self.denominator.clone() * other); + } +} + +impl Mul for Rational +where + F: From + One + Clone + Mul + AddAssign + Add + Zero, +{ + type Output = Rational; + + fn mul(self, other: BaseField) -> Self::Output { + Self::new(self.numerator * other, self.denominator) + } +} + +impl Sub for Rational +where + F: From + + One + + Clone + + Mul + + AddAssign + + Add + + Zero + + Neg, +{ + type Output = Rational; + + fn sub(self, other: BaseField) -> Self::Output { + Self::new( + self.numerator - self.denominator.clone() * other, + self.denominator, + ) + } +} + +impl From> for Rational { + fn from(rational: Rational) -> Self { + Fraction::new(rational.numerator.into(), rational.denominator.into()) + } +} + +impl Add for Rational { + type Output = Rational; + fn add(self, other: SecureField) -> Self::Output { + Rational::::from(self) + Rational::::from(other) + } +} + +impl Mul for Rational { + type Output = Rational; + fn mul(self, other: SecureField) -> Self::Output { + Rational::::from(self) * Rational::::from(other) + } +} + +impl Add for Rational +where + F: From + One + AddAssign + Clone + Add + Mul + Zero, +{ + type Output = Rational; + fn add(self, other: F) -> Rational { + Self::new( + self.numerator + self.denominator.clone() * Polynomial::::from(other), + self.denominator, + ) + } +} + +impl Sub for Rational { + type Output = Rational; + fn sub(self, other: SecureField) -> Self::Output { + self - Rational::::from(other) + } +} + +impl Mul for Rational { + type Output = Rational; + fn mul(self, other: SecureField) -> Self::Output { + self * Rational::::from(other) + } +} + +impl Add> for Rational { + type Output = Rational; + fn add(self, other: Rational) -> Self::Output { + self + Rational::::from(other) + } +} + +impl Mul> for Rational { + type Output = Rational; + fn mul(self, other: Rational) -> Self::Output { + self * Rational::::from(other) + } +} + +impl Div for Rational +where + F: From + One + AddAssign + Clone + Add + Mul + Zero, +{ + type Output = Rational; + fn div(self, other: Rational) -> Rational { + Self::new( + self.numerator.clone() * other.denominator.clone(), + self.denominator.clone() * other.numerator.clone(), + ) + } +} + +impl Display for Rational +where + F: Display + From + Clone + Zero + One + PartialEq, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) / ({})", self.numerator, self.denominator) + } +} diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index 035579e5f..02c81bff4 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -1,10 +1,10 @@ use std::iter::{zip, Sum}; -use std::ops::{Add, Deref, Mul, Neg, Sub}; +use std::ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use crate::core::fields::qm31::SecureField; -use crate::core::fields::{ExtensionOf, Field}; +use crate::core::fields::{ExtensionOf, Field, FieldExpOps}; /// Univariate polynomial stored as coefficients in the monomial basis. #[derive(Debug, Clone)] @@ -212,10 +212,10 @@ impl Fraction { } } -impl< - N: Clone, - D: Add + Add + Mul + Mul + Clone, - > Add for Fraction +impl Add for Fraction +where + N: Clone, + D: Add + Add + Mul + Mul + Clone, { type Output = Fraction; @@ -228,6 +228,22 @@ impl< } } +impl Sub for Fraction +where + N: Clone, + D: Sub + Add + Mul + Mul + Clone, +{ + type Output = Fraction; + + fn sub(self, rhs: Self) -> Fraction { + Fraction { + numerator: rhs.denominator.clone() * self.numerator.clone() + - self.denominator.clone() * rhs.numerator.clone(), + denominator: self.denominator * rhs.denominator, + } + } +} + impl Zero for Fraction where Self: Add, @@ -244,6 +260,25 @@ where } } +impl, D> Neg for Fraction { + type Output = Self; + fn neg(self) -> Self { + Self::new(-self.numerator, self.denominator) + } +} + +impl AddAssign for Fraction +where + N: Mul + Add + Clone, + D: Mul + Clone, +{ + fn add_assign(&mut self, rhs: Self) { + self.numerator = self.numerator.clone() * rhs.denominator.clone() + + rhs.numerator.clone() * self.denominator.clone(); + self.denominator = self.denominator.clone() * rhs.denominator; + } +} + impl Sum for Fraction where Self: Zero, @@ -254,6 +289,36 @@ where } } +impl, D: Mul> Mul for Fraction { + type Output = Self; + fn mul(self, rhs: Self) -> Self { + Self::new( + self.numerator * rhs.numerator, + self.denominator * rhs.denominator, + ) + } +} + +impl MulAssign for Fraction { + fn mul_assign(&mut self, rhs: Self) { + self.numerator *= rhs.numerator; + self.denominator *= rhs.denominator; + } +} + +impl One for Fraction { + fn one() -> Self { + Self::new(N::one(), D::one()) + } +} + +impl FieldExpOps for Fraction { + fn inverse(&self) -> Self { + assert!(!self.denominator.is_zero()); + Self::new(self.denominator.clone(), self.numerator.clone()) + } +} + /// Represents the fraction `1 / x` pub struct Reciprocal { x: T,