Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Eval Framework Copy requirement to Clone. #834

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ impl<E: EvalAtRow> LogupAtRow<E> {

pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.cur_frac {
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
if let Some(cur_frac) = self.cur_frac.clone() {
let [cur_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0]);
let diff = cur_cumsum.clone() - self.prev_col_cumsum.clone();
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
Expand All @@ -76,7 +76,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
pub fn finalize(mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");

let frac = self.cur_frac.unwrap();
let frac = self.cur_frac.clone().unwrap();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
Expand All @@ -89,7 +89,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first);
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first.clone());
(cur_cumsum, prev_row_cumsum)
}
None => {
Expand All @@ -99,8 +99,8 @@ impl<E: EvalAtRow> LogupAtRow<E> {
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum;
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first.clone() * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum.clone();

eval.add_constraint(diff * frac.denominator - frac.numerator);

Expand Down Expand Up @@ -138,9 +138,9 @@ impl<const N: usize> LookupElements<N> {
alpha_powers,
}
}
pub fn combine<F: Copy, EF>(&self, values: &[F]) -> EF
pub fn combine<F: Clone, EF>(&self, values: &[F]) -> EF
where
EF: Copy + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
EF: Clone + Zero + From<F> + From<SecureField> + Mul<F, Output = EF> + Sub<EF, Output = EF>,
{
assert!(
self.alpha_powers.len() >= values.len(),
Expand All @@ -149,8 +149,8 @@ impl<const N: usize> LookupElements<N> {
values
.iter()
.zip(self.alpha_powers)
.fold(EF::zero(), |acc, (&value, power)| {
acc + EF::from(power) * value
.fold(EF::zero(), |acc, (value, power)| {
acc + EF::from(power) * value.clone()
})
- EF::from(self.z)
}
Expand Down
11 changes: 7 additions & 4 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub trait EvalAtRow {
/// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating
/// the columns out of domain.
type F: FieldExpOps
+ Copy
+ Clone
+ Debug
+ Zero
+ Neg<Output = Self::F>
Expand All @@ -48,7 +48,7 @@ pub trait EvalAtRow {
/// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints
/// usually get multiplied by [SecureField] values for security.
type EF: One
+ Copy
+ Clone
+ Debug
+ Zero
+ From<Self::F>
Expand Down Expand Up @@ -84,8 +84,11 @@ pub trait EvalAtRow {
interaction: usize,
offsets: [isize; N],
) -> [Self::EF; N] {
let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets));
array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i])))
let mut res_col_major =
array::from_fn(|_| self.next_interaction_mask(interaction, offsets).into_iter());
array::from_fn(|_| {
Self::combine_ef(res_col_major.each_mut().map(|iter| iter.next().unwrap()))
})
}

/// Adds a constraint to the component.
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ impl SimdBackend {
);

let mut product = F::one();
for &num in mappings.iter() {
for num in mappings.iter() {
if index & 1 == 1 {
product *= num;
product *= *num;
}
index >>= 1;
if index == 0 {
Expand Down Expand Up @@ -108,8 +108,8 @@ impl SimdBackend {
.iter()
.skip(1)
.zip(denom_inverses.iter())
.for_each(|(&m, &d)| {
steps.push(m * d);
.for_each(|(m, d)| {
steps.push(*m * *d);
});
steps.push(F::one());
steps
Expand Down
48 changes: 35 additions & 13 deletions crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ use crate::core::fields::FieldExpOps;
pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;

#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug, Copy)]
#[repr(transparent)]
pub struct Vectorized<A, const N: usize>(pub [A; N]);
pub struct Vectorized<A: Copy, const N: usize>(pub [A; N]);

impl<A, const N: usize> Vectorized<A, N> {
impl<A: Copy, const N: usize> Vectorized<A, N> {
pub fn from_fn<F>(cb: F) -> Self
where
F: FnMut(usize) -> A,
Expand All @@ -27,17 +27,18 @@ impl<A, const N: usize> Vectorized<A, N> {
}
}

impl<A, const N: usize> From<[A; N]> for Vectorized<A, N> {
impl<A: Copy, const N: usize> From<[A; N]> for Vectorized<A, N> {
fn from(array: [A; N]) -> Self {
Vectorized(array)
}
}

unsafe impl<A, const N: usize> Zeroable for Vectorized<A, N> {
unsafe impl<A: Copy, const N: usize> Zeroable for Vectorized<A, N> {
fn zeroed() -> Self {
unsafe { core::mem::zeroed() }
}
}

unsafe impl<A: Pod, const N: usize> Pod for Vectorized<A, N> {}

pub type VeryPackedM31 = Vectorized<PackedM31, N_VERY_PACKED_ELEMS>;
Expand Down Expand Up @@ -121,47 +122,65 @@ impl Scalar for PackedM31 {}
impl Scalar for PackedCM31 {}
impl Scalar for PackedQM31 {}

impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Add<B> + Copy, B: Copy, const N: usize> Add<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Add<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn add(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other.0[i])
}
}

impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N> {
impl<A: Add<B> + Copy, B: Scalar + Copy, const N: usize> Add<B> for Vectorized<A, N>
where
<A as Add<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn add(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] + other)
}
}

impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Sub<B> + Copy, B: Copy, const N: usize> Sub<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Sub<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn sub(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other.0[i])
}
}

impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N> {
impl<A: Sub<B> + Copy, B: Scalar + Copy, const N: usize> Sub<B> for Vectorized<A, N>
where
<A as Sub<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn sub(self, other: B) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] - other)
}
}

impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N> {
impl<A: Mul<B> + Copy, B: Copy, const N: usize> Mul<Vectorized<B, N>> for Vectorized<A, N>
where
<A as Mul<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn mul(self, other: Vectorized<B, N>) -> Self::Output {
Vectorized::from_fn(|i| self.0[i] * other.0[i])
}
}

impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N> {
impl<A: Mul<B> + Copy, B: Scalar + Copy, const N: usize> Mul<B> for Vectorized<A, N>
where
<A as Mul<B>>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

fn mul(self, other: B) -> Self::Output {
Expand Down Expand Up @@ -197,7 +216,10 @@ impl<A: MulAssign<B> + Copy, B: Copy, const N: usize> MulAssign<Vectorized<B, N>
}
}

impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N> {
impl<A: Neg + Copy, const N: usize> Neg for Vectorized<A, N>
where
<A as Neg>::Output: Copy,
{
type Output = Vectorized<A::Output, N>;

#[inline(always)]
Expand All @@ -222,7 +244,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
}
}

impl<A: FieldExpOps + Zero, const N: usize> FieldExpOps for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
Vectorized::from_fn(|i| {
assert!(!self.0[i].is_zero(), "0 has no inverse");
Expand Down
26 changes: 13 additions & 13 deletions crates/prover/src/core/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
}

pub fn double(&self) -> Self {
*self + *self
self.clone() + self.clone()
}

/// Applies the circle's x-coordinate doubling map.
Expand All @@ -40,7 +40,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
/// ```
pub fn double_x(x: F) -> F {
let sx = x.square();
sx + sx - F::one()
sx.clone() + sx - F::one()
}

/// Returns the log order of a point.
Expand All @@ -61,7 +61,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
// we only need the x-coordinate to check order since the only point
// with x=1 is the circle's identity
let mut res = 0;
let mut cur = self.x;
let mut cur = self.x.clone();
while cur != F::one() {
cur = Self::double_x(cur);
res += 1;
Expand All @@ -71,10 +71,10 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>

pub fn mul(&self, mut scalar: u128) -> CirclePoint<F> {
let mut res = Self::zero();
let mut cur = *self;
let mut cur = self.clone();
while scalar > 0 {
if scalar & 1 == 1 {
res = res + cur;
res = res + cur.clone();
}
cur = cur.double();
scalar >>= 1;
Expand All @@ -83,7 +83,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
}

pub fn repeated_double(&self, n: u32) -> Self {
let mut res = *self;
let mut res = self.clone();
for _ in 0..n {
res = res.double();
}
Expand All @@ -92,22 +92,22 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>

pub fn conjugate(&self) -> CirclePoint<F> {
Self {
x: self.x,
y: -self.y,
x: self.x.clone(),
y: -self.y.clone(),
}
}

pub fn antipode(&self) -> CirclePoint<F> {
Self {
x: -self.x,
y: -self.y,
x: -self.x.clone(),
y: -self.y.clone(),
}
}

pub fn into_ef<EF: From<F>>(&self) -> CirclePoint<EF> {
CirclePoint {
x: self.x.into(),
y: self.y.into(),
x: self.x.clone().into(),
y: self.y.clone().into(),
}
}

Expand All @@ -126,7 +126,7 @@ impl<F: Zero + Add<Output = F> + FieldExpOps + Sub<Output = F> + Neg<Output = F>
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let x = self.x * rhs.x - self.y * rhs.y;
let x = self.x.clone() * rhs.x.clone() - self.y.clone() * rhs.y.clone();
let y = self.x * rhs.y + self.y * rhs.x;
Self { x, y }
}
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::fields::m31::BaseField;

pub fn butterfly<F>(v0: &mut F, v1: &mut F, twid: BaseField)
where
F: Copy + AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
F: AddAssign<F> + Sub<F, Output = F> + Mul<BaseField, Output = F> + Copy,
{
let tmp = *v1 * twid;
*v1 = *v0 - tmp;
Expand All @@ -13,7 +13,7 @@ where

pub fn ibutterfly<F>(v0: &mut F, v1: &mut F, itwid: BaseField)
where
F: Copy + AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F>,
F: AddAssign<F> + Add<F, Output = F> + Sub<F, Output = F> + Mul<BaseField, Output = F> + Copy,
{
let tmp = *v0;
*v0 = tmp + *v1;
Expand Down
12 changes: 6 additions & 6 deletions crates/prover/src/core/fields/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ macro_rules! m31 {
/// assert_eq!(pow2147483645(v), v.pow(2147483645));
/// ```
pub fn pow2147483645<T: FieldExpOps>(v: T) -> T {
let t0 = sqn::<2, T>(v) * v;
let t1 = sqn::<1, T>(t0) * t0;
let t2 = sqn::<3, T>(t1) * t0;
let t3 = sqn::<1, T>(t2) * t0;
let t4 = sqn::<8, T>(t3) * t3;
let t5 = sqn::<8, T>(t4) * t3;
let t0 = sqn::<2, T>(v.clone()) * v.clone();
let t1 = sqn::<1, T>(t0.clone()) * t0.clone();
let t2 = sqn::<3, T>(t1.clone()) * t0.clone();
let t3 = sqn::<1, T>(t2.clone()) * t0.clone();
let t4 = sqn::<8, T>(t3.clone()) * t3.clone();
let t5 = sqn::<8, T>(t4.clone()) * t3.clone();
sqn::<7, T>(t5) * t2
}

Expand Down
Loading
Loading