Skip to content

Commit

Permalink
Cleaner API
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush committed Jan 14, 2024
1 parent 963bab9 commit 9cc62a9
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 119 deletions.
118 changes: 0 additions & 118 deletions ec/src/scalar_mul/fixed_base.rs

This file was deleted.

132 changes: 131 additions & 1 deletion ec/src/scalar_mul/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
pub mod glv;
pub mod wnaf;

pub mod fixed_base;
pub mod variable_base;

use crate::short_weierstrass::{Affine, Projective, SWCurveConfig};
use crate::PrimeGroup;
use ark_ff::{AdditiveGroup, Zero};
use ark_std::{
cfg_iter, cfg_iter_mut,
ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
vec::Vec,
};
Expand Down Expand Up @@ -79,4 +79,134 @@ pub trait ScalarMul:
const NEGATION_IS_CHEAP: bool;

fn batch_convert_to_mul_base(bases: &[Self]) -> Vec<Self::MulBase>;

/// Compute the vector v[0].G, v[1].G, ..., v[n-1].G, given:
/// - an element `g`
/// - a list `v` of n scalars
///
/// # Example
/// ```
/// use ark_std::{One, UniformRand};
/// use ark_ec::pairing::Pairing;
/// use ark_test_curves::bls12_381::G1Projective as G;
/// use ark_test_curves::bls12_381::Fr;
/// use ark_ec::scalar_mul::fixed_base::FixedBase;
///
/// // Compute G, s.G, s^2.G, ..., s^9.G
/// let mut rng = ark_std::test_rng();
/// let max_degree = 10;
/// let s = Fr::rand(&mut rng);
/// let g = G::rand(&mut rng);
/// let mut powers_of_s = vec![Fr::one()];
/// let mut cur = s;
/// for _ in 0..max_degree {
/// powers_of_s.push(cur);
/// cur *= &s;
/// }
/// let powers_of_g: Vec<G> = g.batch_mul(&powers_of_s);
/// let naive_powers_of_g: Vec<G> = powers_of_s.iter().map(|e| g * e).collect();
/// assert_eq!(powers_of_g, naive_powers_of_g);
/// ```
fn batch_mul(self, v: &[Self::ScalarField]) -> Vec<Self> {
let table = BatchMulPreprocessing::new(self, v.len());
self.batch_mul_with_preprocessing(v, table)
}

fn batch_mul_with_preprocessing(
self: Self::MulBase,
v: &[Self::ScalarField],
preprocessing: &BatchMulPreprocessing<Self>,
) -> Vec<Self::MulBase> {
cfg_iter!(v).map(|e| preprocessing.windowed_mul(e)).collect()
}
}

/// Preprocessing used internally for batch scalar multiplication via [`ScalarMul::batch_mul`].
/// - `window` is the window size used for the precomputation
/// - `max_scalar_size` is the maximum size of the scalars that will be multiplied
/// - `table` is the precomputed table of multiples of `base`
pub struct BatchMulPreprocessing<T: ScalarMul> {
pub window: usize,
pub max_scalar_size: usize,
pub table: Vec<Vec<T::MulBase>>,
}

impl<T: ScalarMul> BatchMulPreprocessing<T> {
pub fn new(base: T::MulBase, num_scalars: usize) -> Self {
let window = Self::get_mul_window_size(num_scalars);
let scalar_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
Self::with_window_and_scalar_size(base, window, scalar_size)
}

fn window_size(num_scalars: usize) -> usize {
if num_scalars < 32 {
3
} else {
ln_without_floats(num_scalars)
}
}

pub fn with_window_and_scalar_size(
base: T::MulBase,
window: usize,
max_scalar_size: usize,
) -> Self {
let in_window = 1 << window;
let outerc = (max_scalar_size + window - 1) / window;
let last_in_window = 1 << (max_scalar_size - (outerc - 1) * window);

let mut multiples_of_g = vec![vec![T::zero(); in_window]; outerc];

let mut g_outer = base;
let mut g_outers = Vec::with_capacity(outerc);
for _ in 0..outerc {
g_outers.push(g_outer);
for _ in 0..window {
g_outer.double_in_place();
}
}
cfg_iter_mut!(multiples_of_g)
.enumerate()
.take(outerc)
.zip(g_outers)
.for_each(|((outer, multiples_of_g), g_outer)| {
let cur_in_window = if outer == outerc - 1 {
last_in_window
} else {
in_window
};

let mut g_inner = T::zero();
for inner in multiples_of_g.iter_mut().take(cur_in_window) {
*inner = g_inner;
g_inner += &g_outer;
}
});
let table = cfg_iter!(multiples_of_g)
.map(|s| T::batch_convert_to_mul_base(s))
.collect();
Self {
window,
max_scalar_size,
table,
}
}

fn windowed_mul(&self, scalar: &T::ScalarField) -> T {
let outerc = (self.max_scalar_size + self.window - 1) / self.window;
let modulus_size = T::ScalarField::MODULUS_BIT_SIZE as usize;
let scalar_val = scalar.into_bigint().to_bits_le();

let mut res = T::from(self.table[0][0]);
for outer in 0..outerc {
let mut inner = 0usize;
for i in 0..self.window {
if outer * self.window + i < modulus_size && scalar_val[outer * self.window + i] {
inner |= 1 << i;
}
}
res += &self.table[outer][inner];
}
res
}
}

0 comments on commit 9cc62a9

Please sign in to comment.