diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index c2b523014..507507397 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -18,6 +18,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::{ChallengeId, RAMType, WitnessId}, + utils::SimpleVecPool, }; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -141,6 +142,84 @@ impl Expression { } } + #[allow(clippy::too_many_arguments)] + pub fn evaluate_with_instance_pool Vec, PF2: Fn() -> Vec>( + &self, + fixed_in: &impl Fn(&Fixed) -> T, + wit_in: &impl Fn(WitnessId) -> T, // witin id + instance: &impl Fn(Instance) -> T, + constant: &impl Fn(E::BaseField) -> T, + challenge: &impl Fn(ChallengeId, usize, E, E) -> T, + sum: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, + product: &impl Fn( + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, + scaled: &impl Fn( + T, + T, + T, + &mut SimpleVecPool, PF1>, + &mut SimpleVecPool, PF2>, + ) -> T, + pool_e: &mut SimpleVecPool, PF1>, + pool_b: &mut SimpleVecPool, PF2>, + ) -> T { + match self { + Expression::Fixed(f) => fixed_in(f), + Expression::WitIn(witness_id) => wit_in(*witness_id), + Expression::Instance(i) => instance(*i), + Expression::Constant(scalar) => constant(*scalar), + Expression::Sum(a, b) => { + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + sum(a, b, pool_e, pool_b) + } + Expression::Product(a, b) => { + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + product(a, b, pool_e, pool_b) + } + Expression::ScaledSum(x, a, b) => { + let x = x.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let a = a.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + let b = b.evaluate_with_instance_pool( + fixed_in, wit_in, instance, constant, challenge, sum, product, scaled, pool_e, + pool_b, + ); + scaled(x, a, b, pool_e, pool_b) + } + Expression::Challenge(challenge_id, pow, scalar, offset) => { + challenge(*challenge_id, *pow, *scalar, *offset) + } + } + } + pub fn is_monomial_form(&self) -> bool { Self::is_monomial_form_inner(MonomialState::SumTerm, self) } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 945404ff3..9aae503f5 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -3,6 +3,7 @@ #![feature(stmt_expr_attributes)] #![feature(variant_count)] #![feature(strict_overflow_ops)] +#![feature(sync_unsafe_cell)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c8cae8bc..f728cd5ad 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,11 +9,13 @@ use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, - util::ceil_log2, + util::{ceil_log2, max_usable_threads}, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, }; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverStateV2}, @@ -25,16 +27,18 @@ use crate::{ error::ZKVMError, expression::Instance, scheme::{ - constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, NUM_FANIN, NUM_FANIN_LOGUP}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, - wit_infer_by_expr, + wit_infer_by_expr, wit_infer_by_expr_pool, }, }, structs::{ Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, + utils::{ + SimpleVecPool, get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads, + }, virtual_polys::VirtualPolynomials, }; @@ -238,14 +242,39 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference", profiling_3 = true); // main constraint: read/write record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); + let n_threads = max_usable_threads(); let records_wit: Vec> = cs .r_expressions - .par_iter() - .chain(cs.w_expressions.par_iter()) - .chain(cs.lk_expressions.par_iter()) + .iter() + .chain(cs.w_expressions.iter()) + .chain(cs.lk_expressions.iter()) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&[], &witnesses, pi, challenges, expr) + wit_infer_by_expr_pool( + &[], + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let (r_records_wit, w_lk_records_wit) = records_wit.split_at(cs.r_expressions.len()); @@ -701,20 +730,41 @@ impl> ZKVMProver { let wit_inference_span = entered_span!("wit_inference"); // main constraint: lookup denominator and numerator record witness inference let record_span = entered_span!("record"); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); + let n_threads = max_usable_threads(); let mut records_wit: Vec> = cs .r_table_expressions - .par_iter() + .iter() .map(|r| &r.expr) - .chain(cs.w_table_expressions.par_iter().map(|w| &w.expr)) - .chain( - cs.lk_table_expressions - .par_iter() - .map(|lk| &lk.multiplicity), - ) - .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) + .chain(cs.w_table_expressions.iter().map(|w| &w.expr)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) + .chain(cs.lk_table_expressions.iter().map(|lk| &lk.values)) .map(|expr| { assert_eq!(expr.degree(), 1); - wit_infer_by_expr(&fixed, &witnesses, pi, challenges, expr) + wit_infer_by_expr_pool( + &fixed, + &witnesses, + pi, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, + ) }) .collect(); let max_log2_num_instance = records_wit.iter().map(|mle| mle.num_vars()).max().unwrap(); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8ec6453a..af3ed8f8d 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{borrow::Cow, cell::SyncUnsafeCell, sync::Arc}; use ark_std::iterable::Iterable; use ff_ext::ExtensionField; @@ -7,9 +7,12 @@ use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, FieldType, IntoMLE}, op_mle_xa_b, op_mle3_range, - util::ceil_log2, + util::{ceil_log2, max_usable_threads}, virtual_poly_v2::ArcMultilinearExtension, }; + +use ff::Field; + use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, @@ -19,7 +22,9 @@ use rayon::{ }; use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, utils::next_pow2_instance_padding, + expression::Expression, + scheme::constants::MIN_PAR_SIZE, + utils::{SimpleVecPool, next_pow2_instance_padding}, }; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector @@ -233,6 +238,28 @@ pub(crate) fn infer_tower_product_witness( wit_layers } +fn optional_arcpoly_unwrap_pushback( + poly: Cow>, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_expected_size_vec: usize, +) { + let len = poly.evaluations().len(); + if len == pool_expected_size_vec { + match poly { + Cow::Borrowed(_) => (), + Cow::Owned(_) => { + let poly = poly.into_owned(); + match poly.arc_try_unwrap().unwrap() { + FieldType::Base(vec) => pool_b.return_to_pool(vec), + FieldType::Ext(vec) => pool_e.return_to_pool(vec), + _ => unreachable!(), + }; + } + }; + } +} + pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], @@ -240,113 +267,227 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( challenges: &[E; N], expr: &Expression, ) -> ArcMultilinearExtension<'a, E> { - expr.evaluate_with_instance::>( - &|f| fixed[f.0].clone(), - &|witness_id| witnesses[witness_id as usize].clone(), - &|i| instance[i.0].clone(), - &|scalar| { - let scalar: ArcMultilinearExtension = - Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ - scalar, - ])); - scalar - }, - &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be acquired once for each power - let challenge = challenges[challenge_id as usize]; - let challenge: ArcMultilinearExtension = Arc::new( - DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ - challenge.pow([pow as u64]) * scalar + offset, - ]), - ); - challenge - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] + *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a + b[0]) - .collect(), - )), - (_, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a + b) - .collect(), - )), - } - }) - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] * *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a * b[0]) - .collect(), - )), - (_, _) => { - assert_eq!(a.len(), b.len()); - // we do the pointwise evaluation multiplication here without involving FFT - // the evaluations outside of range will be checked via sumcheck + identity polynomial - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a * b) - .collect(), - )) - } - } - }) - }, - &|x, a, b| { - op_mle_xa_b!(|x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(x.len()), - x.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|x| a * x + b) - .collect(), - )) - }) - }, + let n_threads = max_usable_threads(); + let len = witnesses[0].evaluations().len(); + let mut pool_e: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::ZERO) + .collect::>() + }); + let mut pool_b: SimpleVecPool, _> = SimpleVecPool::new(|| { + (0..len) + .into_par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|_| E::BaseField::ZERO) + .collect::>() + }); + wit_infer_by_expr_pool( + fixed, + witnesses, + instance, + challenges, + expr, + n_threads, + &mut pool_e, + &mut pool_b, ) } +#[allow(clippy::too_many_arguments)] +pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>( + fixed: &[ArcMultilinearExtension<'a, E>], + witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[ArcMultilinearExtension<'a, E>], + challenges: &[E; N], + expr: &Expression, + n_threads: usize, + pool_e: &mut SimpleVecPool, impl Fn() -> Vec>, + pool_b: &mut SimpleVecPool, impl Fn() -> Vec>, +) -> ArcMultilinearExtension<'a, E> { + let len = witnesses[0].evaluations().len(); + let poly = + expr.evaluate_with_instance_pool::>, _, _>( + &|f| Cow::Borrowed(&fixed[f.0]), + &|witness_id| Cow::Borrowed(&witnesses[witness_id as usize]), + &|i| Cow::Borrowed(&instance[i.0]), + &|scalar| { + let scalar: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::from_evaluations_vec(0, vec![ + scalar, + ])); + Cow::Owned(scalar) + }, + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be acquired once for each power + let challenge = challenges[challenge_id as usize]; + let challenge: ArcMultilinearExtension = Arc::new( + DenseMultilinearExtension::from_evaluations_ext_vec(0, vec![ + challenge.pow([pow as u64]) * scalar + offset, + ]), + ); + Cow::Owned(challenge) + }, + &|cow_a, cow_b, pool_e, pool_b| { + let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); + let poly = + commutative_op_mle_pair!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] + b[0]], + ), + ); + Cow::Owned(poly) + } + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..b.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[0] + b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] + b[0]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] + b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + } + }, + pool_e, + pool_b + ); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); + poly + }, + &|cow_a, cow_b, pool_e, pool_b| { + let (a, b) = (cow_a.as_ref(), cow_b.as_ref()); + let poly = + commutative_op_mle_pair!( + |a, b, res| { + match (a.len(), b.len()) { + (1, 1) => { + let poly: ArcMultilinearExtension<_> = Arc::new( + DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] * b[0]], + ), + ); + Cow::Owned(poly) + } + (1, _) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..b.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[0] * b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, 1) => { + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] * b[0]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + (_, _) => { + assert_eq!(a.len(), b.len()); + // we do the pointwise evaluation multiplication here without involving FFT + // the evaluations outside of range will be checked via sumcheck + identity polynomial + let res = SyncUnsafeCell::new(res); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..a.len()).skip(thread_id).step_by(n_threads).for_each( + |i| { + *ptr.add(i) = a[i] * b[i]; + }, + ) + }); + Cow::Owned(res.into_inner().into_mle().into()) + } + } + }, + pool_e, + pool_b + ); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); + poly + }, + &|cow_x, cow_a, cow_b, pool_e, pool_b| { + let (x, a, b) = (cow_x.as_ref(), cow_a.as_ref(), cow_b.as_ref()); + let poly = op_mle_xa_b!( + |x, a, b, res| { + let res = SyncUnsafeCell::new(res); + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); + (0..n_threads).into_par_iter().for_each(|thread_id| unsafe { + let ptr = (*res.get()).as_mut_ptr(); + (0..x.len()) + .skip(thread_id) + .step_by(n_threads) + .for_each(|i| { + *ptr.add(i) = a * x[i] + b; + }) + }); + Cow::Owned(res.into_inner().into_mle().into()) + }, + pool_e, + pool_b + ); + optional_arcpoly_unwrap_pushback(cow_a, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_b, pool_e, pool_b, len); + optional_arcpoly_unwrap_pushback(cow_x, pool_e, pool_b, len); + poly + }, + pool_e, + pool_b, + ); + match poly { + Cow::Borrowed(poly) => poly.clone(), + Cow::Owned(_) => poly.into_owned(), + } +} + pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], @@ -416,11 +557,10 @@ mod tests { expression::{Expression, ToExpr}, scheme::utils::{ infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, + wit_infer_by_expr, }, }; - use super::wit_infer_by_expr; - #[test] fn test_infer_tower_witness() { type E = GoldilocksExt2; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 8b7d8cbde..79fa3c490 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, fmt::Display, hash::Hash, panic::{self, PanicHookInfo}, @@ -229,3 +229,30 @@ where result } + +/// a simple vector pool +/// not support multi-thread access +pub struct SimpleVecPool T> { + pool: VecDeque, + factory_fn: F, +} + +impl T> SimpleVecPool { + // new pool with a factory closure + pub fn new(init: F) -> Self { + SimpleVecPool { + pool: VecDeque::new(), + factory_fn: init, + } + } + + // borrow an item from the pool, or create a new one if empty + pub fn borrow(&mut self) -> T { + self.pool.pop_front().unwrap_or_else(|| (self.factory_fn)()) + } + + // push an item to the pool + pub fn return_to_pool(&mut self, item: T) { + self.pool.push_back(item); + } +} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index b4e8df983..85947721d 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -62,6 +62,8 @@ pub trait MultilinearExtension: Send + Sync { _ => panic!("evaluation not in base field"), } } + + fn arc_try_unwrap(self: Arc) -> Option>; } impl Debug for dyn MultilinearExtension> { @@ -821,6 +823,12 @@ impl MultilinearExtension for DenseMultilinearExtension FieldType::Unreachable => unreachable!(), } } + + fn arc_try_unwrap(self: Arc) -> Option> { + Arc::try_unwrap(self) + .ok() + .map(|it| it.evaluations_to_owned()) + } } pub struct RangedMultilinearExtension<'a, E: ExtensionField> { @@ -992,6 +1000,10 @@ impl<'a, E: ExtensionField> MultilinearExtension for RangedMultilinearExtensi fn dup(&self, _num_instances: usize, _num_dups: usize) -> DenseMultilinearExtension { unimplemented!() } + + fn arc_try_unwrap(self: Arc) -> Option> { + unimplemented!() + } } #[macro_export] @@ -1050,33 +1062,91 @@ macro_rules! op_mle3_range { let $bb_out = $op; $op_bb_out }}; + + ($x:ident, $a:ident, $b:ident, $res:ident, $x_vec:ident, $a_vec:ident, $b_vec:ident, $res_vec:ident, $op:expr, |$bb_out:ident| $op_bb_out:expr) => {{ + let $x = if let Some((start, offset)) = $x.evaluations_range() { + &$x_vec[start..][..offset] + } else { + &$x_vec[..] + }; + let $a = if let Some((start, offset)) = $a.evaluations_range() { + &$a_vec[start..][..offset] + } else { + &$a_vec[..] + }; + let $b = if let Some((start, offset)) = $b.evaluations_range() { + &$b_vec[start..][..offset] + } else { + &$b_vec[..] + }; + let $res = $res_vec; + assert_eq!($res.len(), $x.len()); + let $bb_out = $op; + $op_bb_out + }}; } /// deal with x * a + b #[macro_export] macro_rules! op_mle_xa_b { - (|$x:ident, $a:ident, $b:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { + (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { match (&$x.evaluations(), &$a.evaluations(), &$b.evaluations()) { ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Base(a_vec), $crate::mle::FieldType::Base(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_b.borrow(); + op_mle3_range!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Ext(a_vec), $crate::mle::FieldType::Base(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_e.borrow(); + op_mle3_range!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } ( $crate::mle::FieldType::Base(x_vec), $crate::mle::FieldType::Ext(a_vec), $crate::mle::FieldType::Ext(b_vec), ) => { - op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + let res_vec = $pool_e.borrow(); + op_mle3_range!( + $x, + $a, + $b, + $res, + x_vec, + a_vec, + b_vec, + res_vec, + $op, + |$bb_out| { $op_bb_out } + ) } (x, a, b) => unreachable!( "unmatched pattern {:?} {:?} {:?}", @@ -1086,8 +1156,8 @@ macro_rules! op_mle_xa_b { ), } }; - (|$x:ident, $a:ident, $b:ident| $op:expr) => { - op_mle_xa_b!(|$x, $a, $b| $op, |out| out) + (|$x:ident, $a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { + op_mle_xa_b!(|$x, $a, $b, $res| $op, $pool_e, $pool_b, |out| out) }; } @@ -1224,6 +1294,76 @@ macro_rules! commutative_op_mle_pair { _ => unreachable!(), } }; + (|$first:ident, $second:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident, |$bb_out:ident| $op_bb_out:expr) => { + match (&$first.evaluations(), &$second.evaluations()) { + ($crate::mle::FieldType::Base(base1), $crate::mle::FieldType::Base(base2)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &base1[start..][..offset] + } else { + &base1[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base2[start..][..offset] + } else { + &base2[..] + }; + let $res = $pool_b.borrow(); + let $bb_out = $op; + $op_bb_out + } + ($crate::mle::FieldType::Ext(ext), $crate::mle::FieldType::Base(base)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let $res = $pool_e.borrow(); + $op + } + ($crate::mle::FieldType::Base(base), $crate::mle::FieldType::Ext(ext)) => { + let base = if let Some((start, offset)) = $first.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let ext = if let Some((start, offset)) = $second.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + // swap first and second to make ext field come first before base field. + // so the same coding template can apply. + // that's why first and second operand must be commutative + let $first = ext; + let $second = base; + let $res = $pool_e.borrow(); + $op + } + ($crate::mle::FieldType::Ext(ext), $crate::mle::FieldType::Ext(base)) => { + let $first = if let Some((start, offset)) = $first.evaluations_range() { + &ext[start..][..offset] + } else { + &ext[..] + }; + let $second = if let Some((start, offset)) = $second.evaluations_range() { + &base[start..][..offset] + } else { + &base[..] + }; + let $res = $pool_e.borrow(); + $op + } + _ => unreachable!(), + } + }; + (|$a:ident, $b:ident, $res:ident| $op:expr, $pool_e:ident, $pool_b:ident) => { + commutative_op_mle_pair!(|$a, $b, $res| $op, $pool_e, $pool_b, |out| out) + }; (|$a:ident, $b:ident| $op:expr) => { commutative_op_mle_pair!(|$a, $b| $op, |out| out) };